Room autorejoin

This commit is contained in:
Alex 2020-02-26 16:30:10 +01:00
parent 67c7f7361d
commit f3f1b8d981
5 changed files with 87 additions and 31 deletions

View file

@ -3,6 +3,7 @@ package main
import (
"fmt"
"strings"
"sync"
log "github.com/sirupsen/logrus"
@ -18,17 +19,23 @@ type Account struct {
JoinedRooms map[RoomID]bool
}
var accountsLock sync.Mutex
var registeredAccounts = map[string]map[string]*Account{}
func AddAccount(a *Account) {
accountsLock.Lock()
defer accountsLock.Unlock()
if _, ok := registeredAccounts[a.MatrixUser]; !ok {
registeredAccounts[a.MatrixUser] = make(map[string]*Account)
}
registeredAccounts[a.MatrixUser][a.AccountName] = a
ezbrSystemSendf(a.MatrixUser, "Connecting to account %s (%s)", a.AccountName, a.Protocol)
}
func FindAccount(mxUser string, name string) *Account {
accountsLock.Lock()
defer accountsLock.Unlock()
if u, ok := registeredAccounts[mxUser]; ok {
if a, ok := u[name]; ok {
return a
@ -38,6 +45,9 @@ func FindAccount(mxUser string, name string) *Account {
}
func FindJoinedAccount(mxUser string, protocol string, room RoomID) *Account {
accountsLock.Lock()
defer accountsLock.Unlock()
if u, ok := registeredAccounts[mxUser]; ok {
for _, acct := range u {
if acct.Protocol == protocol {
@ -51,17 +61,55 @@ func FindJoinedAccount(mxUser string, protocol string, room RoomID) *Account {
}
func RemoveAccount(mxUser string, name string) {
accountsLock.Lock()
defer accountsLock.Unlock()
if u, ok := registeredAccounts[mxUser]; ok {
delete(u, name)
}
}
// ----
func (a *Account) ezbrMessagef(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
msg = fmt.Sprintf("%s: %s", a.Protocol, msg)
ezbrSystemSend(a.MatrixUser, msg)
}
func (a *Account) connect(config map[string]string, join_rooms []string) {
ezbrSystemSendf(a.MatrixUser, "Connecting to account %s (%s)", a.AccountName, a.Protocol)
err := a.Conn.Configure(config)
if err != nil {
ezbrSystemSendf(a.MatrixUser, "%s (%s) cannot connect: %s", a.AccountName, a.Protocol, err.Error())
return
}
for _, room := range join_rooms {
var entry DbJoinedRoom
db.Where(&DbJoinedRoom{
MxUserID: a.MatrixUser,
Protocol: a.Protocol,
AccountName: a.AccountName,
RoomID: RoomID(room),
}).FirstOrCreate(&entry)
}
var autojoin []DbJoinedRoom
db.Where(&DbJoinedRoom{
MxUserID: a.MatrixUser,
Protocol: a.Protocol,
AccountName: a.AccountName,
}).Find(&autojoin)
for _, aj := range autojoin {
err := a.Conn.Join(aj.RoomID)
if err != nil {
ezbrSystemSendf(a.MatrixUser, "%s (%s) cannot join %s: %s", a.AccountName, a.Protocol, aj.RoomID, err.Error())
}
}
}
// ---- Begin event handlers ----
func (a *Account) Joined(roomId RoomID) {
@ -69,6 +117,14 @@ func (a *Account) Joined(roomId RoomID) {
if err != nil {
a.ezbrMessagef("Dropping Account.Joined %s: %s", roomId, err.Error())
}
var entry DbJoinedRoom
db.Where(&DbJoinedRoom{
MxUserID: a.MatrixUser,
Protocol: a.Protocol,
AccountName: a.AccountName,
RoomID: roomId,
}).FirstOrCreate(&entry)
}
func (a *Account) joinedInternal(roomId RoomID) error {
@ -95,6 +151,13 @@ func (a *Account) Left(roomId RoomID) {
if err != nil {
a.ezbrMessagef("Dropping Account.Left %s: %s", roomId, err.Error())
}
db.Where(&DbJoinedRoom{
MxUserID: a.MatrixUser,
Protocol: a.Protocol,
AccountName: a.AccountName,
RoomID: roomId,
}).Delete(&DbJoinedRoom{})
}
func (a *Account) leftInternal(roomId RoomID) error {

View file

@ -81,7 +81,8 @@ func (xm *XMPP) Configure(c Configuration) error {
return fmt.Errorf("JID %s not on server %s", xm.jid, xm.server)
}
xm.jid_localpart = jid_parts[0]
xm.nickname = xm.jid_localpart
xm.nickname = c.GetString("nickname", xm.jid_locakpart)
xm.password, err = c.GetString("password")
if err != nil {

16
db.go
View file

@ -36,6 +36,9 @@ func InitDb() error {
db.AutoMigrate(&DbPmRoomMap{})
db.Model(&DbPmRoomMap{}).AddIndex("idx_protocol_user_account_user", "protocol", "user_id", "mx_user_id", "account_name")
db.AutoMigrate(&DbJoinedRoom{})
db.Model(&DbJoinedRoom{}).AddIndex("idx_user_protocol_account", "mx_user_id", "protocol", "account_name")
return nil
}
@ -86,6 +89,19 @@ type DbPmRoomMap struct {
MxRoomID string `gorm:"index:mxroomoid"`
}
// List of joined channels to be re-joined on reconnect
type DbJoinedRoom struct {
gorm.Model
// User id and account name
MxUserID string
Protocol string
AccountName string
// Room ID
RoomID connector.RoomID
}
// ---- Simple locking mechanism
var dbLocks [256]sync.Mutex

22
main.go
View file

@ -29,7 +29,7 @@ type ConfigAccount struct {
}
type ConfigFile struct {
HttpBindAddr string `json:"http_bind_addr"`
ASBindAddr string `json:"appservice_bind_addr"`
Registration string `json:"registration"`
Server string `json:"homeserver_url"`
DbType string `json:"db_type"`
@ -45,7 +45,7 @@ var registration *mxlib.Registration
func readConfig() ConfigFile {
config_file := ConfigFile{
HttpBindAddr: "0.0.0.0:8321",
ASBindAddr: "0.0.0.0:8321",
Registration: "./registration.yaml",
Server: "http://localhost:8008",
DbType: "sqlite3",
@ -192,7 +192,7 @@ func main() {
}
conn.SetHandler(account)
AddAccount(account)
go connectAndJoin(account, params)
go account.connect(params.Config, params.Rooms)
}
}
@ -201,19 +201,3 @@ func main() {
log.Fatal(err)
}
}
func connectAndJoin(account *Account, params ConfigAccount) {
log.Printf("Connecting to %s", params.Protocol)
err := account.Conn.Configure(params.Config)
if err != nil {
log.Printf("Could not connect to %s: %s", params.Protocol, err)
} else {
log.Printf("Connected to %s, now joining %#v", params.Protocol, params.Rooms)
for _, room := range params.Rooms {
err := account.Conn.Join(connector.RoomID(room))
if err != nil {
log.Printf("Could not join %s: %s", room, err)
}
}
}
}

View file

@ -13,14 +13,6 @@ import (
"git.deuxfleurs.fr/Deuxfleurs/easybridge/mxlib"
)
type Config struct {
HttpBindAddr string
Server string
DbType string
DbPath string
MatrixDomain string
}
var mx *mxlib.Client
func StartAppService() (chan error, error) {
@ -55,8 +47,8 @@ func StartAppService() (chan error, error) {
errch := make(chan error)
go func() {
log.Printf("Starting HTTP server on %s", config.HttpBindAddr)
err := http.ListenAndServe(config.HttpBindAddr, checkTokenAndLog(router))
log.Printf("Starting HTTP server on %s", config.ASBindAddr)
err := http.ListenAndServe(config.ASBindAddr, checkTokenAndLog(router))
if err != nil {
errch <- err
}