diff --git a/database/portal.go b/database/portal.go index 25e1dc9..9772387 100644 --- a/database/portal.go +++ b/database/portal.go @@ -42,7 +42,7 @@ func NewPortalKey(jid, receiver types.WhatsAppID) PortalKey { receiver = jid } return PortalKey{ - JID: jid, + JID: jid, Receiver: receiver, } } @@ -152,3 +152,26 @@ func (portal *Portal) Delete() { portal.log.Warnfln("Failed to delete %s: %v", portal.Key, err) } } + +func (portal *Portal) GetUserIDs() []types.MatrixUserID { + rows, err := portal.db.Query(`SELECT "user".mxid FROM "user", user_portal + WHERE "user".jid=user_portal.user_jid + AND user_portal.portal_jid=$1 + AND user_portal.portal_receiver=$2`, + portal.Key.JID, portal.Key.Receiver) + if err != nil { + portal.log.Debugln("Failed to get portal user ids:", err) + return nil + } + var userIDs []types.MatrixUserID + for rows.Next() { + var userID types.MatrixUserID + err = rows.Scan(&userID) + if err != nil { + portal.log.Warnln("Failed to scan row:", err) + continue + } + userIDs = append(userIDs, userID) + } + return userIDs +} diff --git a/database/upgrades/2019-05-28-user-portal-table.go b/database/upgrades/2019-05-28-user-portal-table.go new file mode 100644 index 0000000..cbc44a7 --- /dev/null +++ b/database/upgrades/2019-05-28-user-portal-table.go @@ -0,0 +1,19 @@ +package upgrades + +import ( + "database/sql" +) + +func init() { + upgrades[6] = upgrade{"Add user-portal mapping table", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error { + _, err := tx.Exec(`CREATE TABLE user_portal ( + user_jid VARCHAR(255), + portal_jid VARCHAR(255), + portal_receiver VARCHAR(255), + PRIMARY KEY (user_jid, portal_jid, portal_receiver), + FOREIGN KEY (user_jid) REFERENCES "user"(jid) ON DELETE CASCADE, + FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE + )`) + return err + }} +} diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go index fe96e23..6f2faa1 100644 --- a/database/upgrades/upgrades.go +++ b/database/upgrades/upgrades.go @@ -22,7 +22,7 @@ type upgrade struct { fn upgradeFunc } -const NumberOfUpgrades = 6 +const NumberOfUpgrades = 7 var upgrades [NumberOfUpgrades]upgrade diff --git a/database/user.go b/database/user.go index 4b73e62..300d12e 100644 --- a/database/user.go +++ b/database/user.go @@ -18,6 +18,7 @@ package database import ( "database/sql" + "fmt" "strings" "time" @@ -165,3 +166,50 @@ func (user *User) Update() { user.log.Warnfln("Failed to update %s: %v", user.MXID, err) } } + +func (user *User) SetPortalKeys(newKeys []PortalKey) error { + tx, err := user.db.Begin() + if err != nil { + return err + } + _, err = tx.Exec("DELETE FROM user_portal WHERE user_jid=$1", user.jidPtr()) + if err != nil { + _ = tx.Rollback() + return err + } + valueStrings := make([]string, len(newKeys)) + values := make([]interface{}, len(newKeys)*3) + for i, key := range newKeys { + valueStrings[i] = fmt.Sprintf("($%d, $%d, $%d)", i*3+1, i*3+2, i*3+3) + values[i*3] = user.jidPtr() + values[i*3+1] = key.JID + values[i*3+2] = key.Receiver + } + query := fmt.Sprintf("INSERT INTO user_portal (user_jid, portal_jid, portal_receiver) VALUES %s", + strings.Join(valueStrings, ", ")) + _, err = tx.Exec(query, values...) + if err != nil { + _ = tx.Rollback() + return err + } + return tx.Commit() +} + +func (user *User) GetPortalKeys() []PortalKey { + rows, err := user.db.Query(`SELECT portal_jid, portal_receiver FROM user_portal WHERE user_jid=$1`, user.jidPtr()) + if err != nil { + user.log.Warnln("Failed to get user portal keys:", err) + return nil + } + var keys []PortalKey + for rows.Next() { + var key PortalKey + err = rows.Scan(&key.JID, &key.Receiver) + if err != nil { + user.log.Warnln("Failed to scan row:", err) + continue + } + keys = append(keys, key) + } + return keys +} diff --git a/portal.go b/portal.go index 36263c1..cdc479f 100644 --- a/portal.go +++ b/portal.go @@ -50,15 +50,7 @@ func (bridge *Bridge) GetPortalByMXID(mxid types.MatrixRoomID) *Portal { defer bridge.portalsLock.Unlock() portal, ok := bridge.portalsByMXID[mxid] if !ok { - dbPortal := bridge.DB.Portal.GetByMXID(mxid) - if dbPortal == nil { - return nil - } - portal = bridge.NewPortal(dbPortal) - bridge.portalsByJID[portal.Key] = portal - if len(portal.MXID) > 0 { - bridge.portalsByMXID[portal.MXID] = portal - } + return bridge.loadDBPortal(bridge.DB.Portal.GetByMXID(mxid), nil) } return portal } @@ -68,17 +60,7 @@ func (bridge *Bridge) GetPortalByJID(key database.PortalKey) *Portal { defer bridge.portalsLock.Unlock() portal, ok := bridge.portalsByJID[key] if !ok { - dbPortal := bridge.DB.Portal.GetByJID(key) - if dbPortal == nil { - dbPortal = bridge.DB.Portal.New() - dbPortal.Key = key - dbPortal.Insert() - } - portal = bridge.NewPortal(dbPortal) - bridge.portalsByJID[portal.Key] = portal - if len(portal.MXID) > 0 { - bridge.portalsByMXID[portal.MXID] = portal - } + return bridge.loadDBPortal(bridge.DB.Portal.GetByJID(key), &key) } return portal } @@ -91,17 +73,34 @@ func (bridge *Bridge) GetAllPortals() []*Portal { for index, dbPortal := range dbPortals { portal, ok := bridge.portalsByJID[dbPortal.Key] if !ok { - portal = bridge.NewPortal(dbPortal) - bridge.portalsByJID[portal.Key] = portal - if len(dbPortal.MXID) > 0 { - bridge.portalsByMXID[dbPortal.MXID] = portal - } + portal = bridge.loadDBPortal(dbPortal, nil) } output[index] = portal } return output } +func (bridge *Bridge) loadDBPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal { + if dbPortal == nil { + if key == nil { + return nil + } + dbPortal = bridge.DB.Portal.New() + dbPortal.Key = *key + dbPortal.Insert() + } + portal := bridge.NewPortal(dbPortal) + bridge.portalsByJID[portal.Key] = portal + if len(portal.MXID) > 0 { + bridge.portalsByMXID[portal.MXID] = portal + } + return portal +} + +func (portal *Portal) GetUsers() []*User { + return nil +} + func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal { portal := &Portal{ Portal: dbPortal, diff --git a/user.go b/user.go index 8713c3a..ec85b41 100644 --- a/user.go +++ b/user.go @@ -66,20 +66,7 @@ func (bridge *Bridge) GetUserByMXID(userID types.MatrixUserID) *User { defer bridge.usersLock.Unlock() user, ok := bridge.usersByMXID[userID] if !ok { - dbUser := bridge.DB.User.GetByMXID(userID) - if dbUser == nil { - dbUser = bridge.DB.User.New() - dbUser.MXID = userID - dbUser.Insert() - } - user = bridge.NewUser(dbUser) - bridge.usersByMXID[user.MXID] = user - if len(user.JID) > 0 { - bridge.usersByJID[user.JID] = user - } - if len(user.ManagementRoom) > 0 { - bridge.managementRooms[user.ManagementRoom] = user - } + return bridge.loadDBUser(bridge.DB.User.GetByMXID(userID), &userID) } return user } @@ -89,16 +76,7 @@ func (bridge *Bridge) GetUserByJID(userID types.WhatsAppID) *User { defer bridge.usersLock.Unlock() user, ok := bridge.usersByJID[userID] if !ok { - dbUser := bridge.DB.User.GetByJID(userID) - if dbUser == nil { - return nil - } - user = bridge.NewUser(dbUser) - bridge.usersByMXID[user.MXID] = user - bridge.usersByJID[user.JID] = user - if len(user.ManagementRoom) > 0 { - bridge.managementRooms[user.ManagementRoom] = user - } + return bridge.loadDBUser(bridge.DB.User.GetByJID(userID), nil) } return user } @@ -111,20 +89,50 @@ func (bridge *Bridge) GetAllUsers() []*User { for index, dbUser := range dbUsers { user, ok := bridge.usersByMXID[dbUser.MXID] if !ok { - user = bridge.NewUser(dbUser) - bridge.usersByMXID[user.MXID] = user - if len(user.JID) > 0 { - bridge.usersByJID[user.JID] = user - } - if len(user.ManagementRoom) > 0 { - bridge.managementRooms[user.ManagementRoom] = user - } + user = bridge.loadDBUser(dbUser, nil) } output[index] = user } return output } +func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *types.MatrixUserID) *User { + if dbUser == nil { + if mxid == nil { + return nil + } + dbUser = bridge.DB.User.New() + dbUser.MXID = *mxid + dbUser.Insert() + } + user := bridge.NewUser(dbUser) + bridge.usersByMXID[user.MXID] = user + if len(user.JID) > 0 { + bridge.usersByJID[user.JID] = user + } + if len(user.ManagementRoom) > 0 { + bridge.managementRooms[user.ManagementRoom] = user + } + return user +} + +func (user *User) GetPortals() []*Portal { + keys := user.User.GetPortalKeys() + portals := make([]*Portal, len(keys)) + + user.bridge.portalsLock.Lock() + defer user.bridge.portalsLock.Unlock() + + for i, key := range keys { + portal, ok := user.bridge.portalsByJID[key] + if !ok { + portal = user.bridge.loadDBPortal(user.bridge.DB.Portal.GetByJID(key), &key) + } + portals[i] = portal + } + return portals +} + func (bridge *Bridge) NewUser(dbUser *database.User) *User { user := &User{ User: dbUser, @@ -295,18 +303,26 @@ func (user *User) PostLogin() { } func (user *User) syncPortals(createAll bool) { - var chats ChatList + chats := make(ChatList, 0, len(user.Conn.Store.Chats)) + portalKeys := make([]database.PortalKey, 0, len(user.Conn.Store.Chats)) for _, chat := range user.Conn.Store.Chats { ts, err := strconv.ParseUint(chat.LastMessageTime, 10, 64) if err != nil { user.log.Warnfln("Non-integer last message time in %s: %s", chat.Jid, chat.LastMessageTime) continue } + portal := user.GetPortalByJID(chat.Jid) + chats = append(chats, Chat{ - Portal: user.GetPortalByJID(chat.Jid), + Portal: portal, Contact: user.Conn.Store.Contacts[chat.Jid], LastMessageTime: ts, }) + portalKeys = append(portalKeys, portal.Key) + } + err := user.SetPortalKeys(portalKeys) + if err != nil { + user.log.Warnln("Failed to update user-portal mapping:", err) } sort.Sort(chats) limit := user.bridge.Config.Bridge.InitialChatSync @@ -315,7 +331,7 @@ func (user *User) syncPortals(createAll bool) { } now := uint64(time.Now().Unix()) for i, chat := range chats { - if chat.LastMessageTime + user.bridge.Config.Bridge.SyncChatMaxAge < now { + if chat.LastMessageTime+user.bridge.Config.Bridge.SyncChatMaxAge < now { break } create := (chat.LastMessageTime >= user.LastConnection && user.LastConnection > 0) || i < limit