diff --git a/appservice/db.go b/appservice/db.go index 9c3280c..34fc046 100644 --- a/appservice/db.go +++ b/appservice/db.go @@ -2,12 +2,14 @@ package appservice import ( "fmt" + "sync" "github.com/jinzhu/gorm" _ "github.com/jinzhu/gorm/dialects/mysql" _ "github.com/jinzhu/gorm/dialects/postgres" _ "github.com/jinzhu/gorm/dialects/sqlite" log "github.com/sirupsen/logrus" + "golang.org/x/crypto/blake2b" "git.deuxfleurs.fr/Deuxfleurs/easybridge/connector" "git.deuxfleurs.fr/Deuxfleurs/easybridge/mxlib" @@ -84,6 +86,20 @@ type DbPmRoomMap struct { MxRoomID string `gorm:"index:mxroomoid"` } +// ---- Simple locking mechanism + +var dbLocks [256]sync.Mutex + +func dbLockSlot(key string) { + slot := blake2b.Sum512([]byte(key))[0] + dbLocks[slot].Lock() +} + +func dbUnlockSlot(key string) { + slot := blake2b.Sum512([]byte(key))[0] + dbLocks[slot].Unlock() +} + // ---- func dbCacheGet(key string) string { @@ -101,7 +117,9 @@ func dbCachePut(key string, value string) { } func dbCacheTestAndSet(key string, value string) bool { - // TODO make this really an atomic operation + dbLockSlot(key) + defer dbUnlockSlot(key) + // True if value was changed, false if was already set if dbCacheGet(key) != value { dbCachePut(key, value) @@ -111,6 +129,10 @@ func dbCacheTestAndSet(key string, value string) bool { } func dbGetMxRoom(protocol string, roomId connector.RoomID) (string, error) { + slot_key := fmt.Sprintf("room: %s / %s", protocol, roomId) + dbLockSlot(slot_key) + defer dbUnlockSlot(slot_key) + var room DbRoomMap // Check if room exists in our mapping, @@ -148,6 +170,10 @@ func dbGetMxRoom(protocol string, roomId connector.RoomID) (string, error) { } func dbGetMxPmRoom(protocol string, them connector.UserID, themMxId string, usMxId string, usAccount string) (string, error) { + slot_key := fmt.Sprintf("pmroom: %s / %s / %s / %s", protocol, usMxId, usAccount, them) + dbLockSlot(slot_key) + defer dbUnlockSlot(slot_key) + var room DbPmRoomMap must_create := db.First(&room, DbPmRoomMap{ @@ -186,6 +212,10 @@ func dbGetMxPmRoom(protocol string, them connector.UserID, themMxId string, usMx } func dbGetMxUser(protocol string, userId connector.UserID) (string, error) { + slot_key := fmt.Sprintf("user: %s / %s", protocol, userId) + dbLockSlot(slot_key) + defer dbUnlockSlot(slot_key) + var user DbUserMap must_create := db.First(&user, DbUserMap{ diff --git a/connector/mattermost/mattermost.go b/connector/mattermost/mattermost.go index 010e102..51ad636 100644 --- a/connector/mattermost/mattermost.go +++ b/connector/mattermost/mattermost.go @@ -7,6 +7,7 @@ import ( "net/http" _ "os" "strings" + "sync" "time" "github.com/42wim/matterbridge/matterclient" @@ -29,9 +30,15 @@ type Mattermost struct { conn *matterclient.MMClient handlerStopChan chan bool - usermap map[string]string // map username to mm user id - sentjoinedmap map[string]bool // map username/room name to bool - userdisplaynamemap map[UserID]string // map username to last displayname + caches mmCaches +} + +type mmCaches struct { + sync.Mutex + + mmusers map[string]string // map mm username to mm user id + sentjoined map[string]bool // map username/room name to bool + displayname map[UserID]string // map username to last displayname } func (mm *Mattermost) SetHandler(h Handler) { @@ -162,14 +169,20 @@ func (mm *Mattermost) checkUserId(id UserID) (string, error) { if len(x) != 2 || x[1] != mm.server { return "", fmt.Errorf("Invalid user ID: %s", id) } - if user_id, ok := mm.usermap[x[0]]; ok { + + mm.caches.Lock() + defer mm.caches.Unlock() + + if user_id, ok := mm.caches.mmusers[x[0]]; ok { return user_id, nil } + u, resp := mm.conn.Client.GetUserByUsername(x[0], "") if u == nil || resp.Error != nil { return "", fmt.Errorf("Not found: %s (%s)", x[0], resp.Error) } - mm.usermap[x[0]] = u.Id + mm.caches.mmusers[x[0]] = u.Id + return u.Id, nil } @@ -294,24 +307,45 @@ func (mm *Mattermost) Close() { } func (mm *Mattermost) handleConnected() { + // Reinitialize shared data structures mm.handlerStopChan = make(chan bool) - mm.usermap = make(map[string]string) - mm.sentjoinedmap = make(map[string]bool) - mm.userdisplaynamemap = make(map[UserID]string) - go mm.handleLoop(mm.conn.MessageChan, mm.handlerStopChan) + + mm.caches.mmusers = make(map[string]string) + mm.caches.sentjoined = make(map[string]bool) + mm.caches.displayname = make(map[UserID]string) fmt.Printf("Connected to mattermost\n") - chans := mm.conn.GetChannels() - for _, ch := range chans { - if len(strings.Split(ch.Name, "__")) == 2 { - continue // This is a DM channel - } + // Handle incoming messages + go mm.handleLoop(mm.conn.MessageChan, mm.handlerStopChan) + // Initial channel sync + chans := mm.conn.GetChannels() + doneCh := make(map[string]bool) + for _, ch := range chans { + if _, ok := doneCh[ch.Id]; !ok { + doneCh[ch.Id] = true + go mm.initSyncChannel(ch) + } + } +} + +func (mm *Mattermost) initSyncChannel(ch *model.Channel) { + if len(strings.Split(ch.Name, "__")) == 2 { + // DM channel + // Update remote user info + users := strings.Split(ch.Name, "__") + for _, uid := range users { + user := mm.conn.GetUser(uid) + if user != nil && uid != mm.conn.User.Id { + mm.updateUserInfo(user) + } + } + } else { interested, id := mm.reverseRoomId(ch.Id) if !interested { // Skip channels that are not in teams we want to bridge - continue + return } mm.handler.Joined(id) @@ -368,28 +402,29 @@ func (mm *Mattermost) handleConnected() { } else { log.Warnf("Could not get channel members: %s", resp.Error.Error()) } + } - // Read backlog - var backlog *model.PostList - last_seen_post := mm.handler.CacheGet(fmt.Sprintf("last_seen_%s", ch.Id)) - if last_seen_post != "" { - backlog, resp = mm.conn.Client.GetPostsAfter(ch.Id, last_seen_post, 0, 1000, "") - // TODO: if there are more than 1000, loop around - } else { - backlog, resp = mm.conn.Client.GetPostsForChannel(ch.Id, 0, 1000, "") - } - if resp.Error == nil { - for i := 0; i < len(backlog.Order); i++ { - post_id := backlog.Order[len(backlog.Order)-i-1] - post := backlog.Posts[post_id] - post_time := time.Unix(post.CreateAt/1000, 0) - post.Message = fmt.Sprintf("[%s] %s", - post_time.Format("2006-01-02 15:04:05 MST"), post.Message) - mm.handlePost(ch.Name, post, true) - } - } else { - log.Warnf("Could not get channel backlog: %s", resp.Error) + // Read backlog + var backlog *model.PostList + var resp *model.Response + last_seen_post := mm.handler.CacheGet(fmt.Sprintf("last_seen_%s", ch.Id)) + if last_seen_post != "" { + backlog, resp = mm.conn.Client.GetPostsAfter(ch.Id, last_seen_post, 0, 1000, "") + // TODO: if there are more than 1000, loop around + } else { + backlog, resp = mm.conn.Client.GetPostsForChannel(ch.Id, 0, 1000, "") + } + if resp.Error == nil { + for i := 0; i < len(backlog.Order); i++ { + post_id := backlog.Order[len(backlog.Order)-i-1] + post := backlog.Posts[post_id] + post_time := time.Unix(post.CreateAt/1000, 0) + post.Message = fmt.Sprintf("[%s] %s", + post_time.Format("2006-01-02 15:04:05 MST"), post.Message) + mm.handlePost(ch.Name, post, true) } + } else { + log.Warnf("Could not get channel backlog: %s", resp.Error) } } @@ -413,7 +448,10 @@ func (mm *Mattermost) updateUserInfo(user *model.User) { userId := UserID(fmt.Sprintf("%s@%s", user.Username, mm.server)) userDisp := user.GetDisplayName(model.SHOW_NICKNAME_FULLNAME) - if lastdn, ok := mm.userdisplaynamemap[userId]; !ok || lastdn != userDisp { + mm.caches.Lock() + defer mm.caches.Unlock() + + if lastdn, ok := mm.caches.displayname[userId]; !ok || lastdn != userDisp { ui := &UserInfo{ DisplayName: userDisp, } @@ -435,20 +473,24 @@ func (mm *Mattermost) updateUserInfo(user *model.User) { } } mm.handler.UserInfoUpdated(userId, ui) - mm.userdisplaynamemap[userId] = userDisp + mm.caches.displayname[userId] = userDisp } } func (mm *Mattermost) ensureJoined(user *model.User, roomId RoomID) { userId := UserID(fmt.Sprintf("%s@%s", user.Username, mm.server)) cache_key := fmt.Sprintf("%s / %s", userId, roomId) - if _, ok := mm.sentjoinedmap[cache_key]; !ok { + + mm.caches.Lock() + defer mm.caches.Unlock() + + if _, ok := mm.caches.sentjoined[cache_key]; !ok { mm.handler.Event(&Event{ Author: userId, Room: roomId, Type: EVENT_JOIN, }) - mm.sentjoinedmap[cache_key] = true + mm.caches.sentjoined[cache_key] = true } } diff --git a/go.mod b/go.mod index d7fdba8..c0f3f77 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/mattermost/mattermost-server v5.11.1+incompatible github.com/mattn/go-xmpp v0.0.0-20200128155807-a86b6abcb3ad github.com/sirupsen/logrus v1.4.2 + golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d // indirect gopkg.in/yaml.v2 v2.2.8 )