Add user-portal mapping to database
This commit is contained in:
parent
666194b066
commit
dce08b1422
6 changed files with 167 additions and 62 deletions
|
@ -42,7 +42,7 @@ func NewPortalKey(jid, receiver types.WhatsAppID) PortalKey {
|
||||||
receiver = jid
|
receiver = jid
|
||||||
}
|
}
|
||||||
return PortalKey{
|
return PortalKey{
|
||||||
JID: jid,
|
JID: jid,
|
||||||
Receiver: receiver,
|
Receiver: receiver,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -152,3 +152,26 @@ func (portal *Portal) Delete() {
|
||||||
portal.log.Warnfln("Failed to delete %s: %v", portal.Key, err)
|
portal.log.Warnfln("Failed to delete %s: %v", portal.Key, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (portal *Portal) GetUserIDs() []types.MatrixUserID {
|
||||||
|
rows, err := portal.db.Query(`SELECT "user".mxid FROM "user", user_portal
|
||||||
|
WHERE "user".jid=user_portal.user_jid
|
||||||
|
AND user_portal.portal_jid=$1
|
||||||
|
AND user_portal.portal_receiver=$2`,
|
||||||
|
portal.Key.JID, portal.Key.Receiver)
|
||||||
|
if err != nil {
|
||||||
|
portal.log.Debugln("Failed to get portal user ids:", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var userIDs []types.MatrixUserID
|
||||||
|
for rows.Next() {
|
||||||
|
var userID types.MatrixUserID
|
||||||
|
err = rows.Scan(&userID)
|
||||||
|
if err != nil {
|
||||||
|
portal.log.Warnln("Failed to scan row:", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
userIDs = append(userIDs, userID)
|
||||||
|
}
|
||||||
|
return userIDs
|
||||||
|
}
|
||||||
|
|
19
database/upgrades/2019-05-28-user-portal-table.go
Normal file
19
database/upgrades/2019-05-28-user-portal-table.go
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
package upgrades
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
upgrades[6] = upgrade{"Add user-portal mapping table", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error {
|
||||||
|
_, err := tx.Exec(`CREATE TABLE user_portal (
|
||||||
|
user_jid VARCHAR(255),
|
||||||
|
portal_jid VARCHAR(255),
|
||||||
|
portal_receiver VARCHAR(255),
|
||||||
|
PRIMARY KEY (user_jid, portal_jid, portal_receiver),
|
||||||
|
FOREIGN KEY (user_jid) REFERENCES "user"(jid) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
|
||||||
|
)`)
|
||||||
|
return err
|
||||||
|
}}
|
||||||
|
}
|
|
@ -22,7 +22,7 @@ type upgrade struct {
|
||||||
fn upgradeFunc
|
fn upgradeFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
const NumberOfUpgrades = 6
|
const NumberOfUpgrades = 7
|
||||||
|
|
||||||
var upgrades [NumberOfUpgrades]upgrade
|
var upgrades [NumberOfUpgrades]upgrade
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -165,3 +166,50 @@ func (user *User) Update() {
|
||||||
user.log.Warnfln("Failed to update %s: %v", user.MXID, err)
|
user.log.Warnfln("Failed to update %s: %v", user.MXID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (user *User) SetPortalKeys(newKeys []PortalKey) error {
|
||||||
|
tx, err := user.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec("DELETE FROM user_portal WHERE user_jid=$1", user.jidPtr())
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
valueStrings := make([]string, len(newKeys))
|
||||||
|
values := make([]interface{}, len(newKeys)*3)
|
||||||
|
for i, key := range newKeys {
|
||||||
|
valueStrings[i] = fmt.Sprintf("($%d, $%d, $%d)", i*3+1, i*3+2, i*3+3)
|
||||||
|
values[i*3] = user.jidPtr()
|
||||||
|
values[i*3+1] = key.JID
|
||||||
|
values[i*3+2] = key.Receiver
|
||||||
|
}
|
||||||
|
query := fmt.Sprintf("INSERT INTO user_portal (user_jid, portal_jid, portal_receiver) VALUES %s",
|
||||||
|
strings.Join(valueStrings, ", "))
|
||||||
|
_, err = tx.Exec(query, values...)
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (user *User) GetPortalKeys() []PortalKey {
|
||||||
|
rows, err := user.db.Query(`SELECT portal_jid, portal_receiver FROM user_portal WHERE user_jid=$1`, user.jidPtr())
|
||||||
|
if err != nil {
|
||||||
|
user.log.Warnln("Failed to get user portal keys:", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var keys []PortalKey
|
||||||
|
for rows.Next() {
|
||||||
|
var key PortalKey
|
||||||
|
err = rows.Scan(&key.JID, &key.Receiver)
|
||||||
|
if err != nil {
|
||||||
|
user.log.Warnln("Failed to scan row:", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
49
portal.go
49
portal.go
|
@ -50,15 +50,7 @@ func (bridge *Bridge) GetPortalByMXID(mxid types.MatrixRoomID) *Portal {
|
||||||
defer bridge.portalsLock.Unlock()
|
defer bridge.portalsLock.Unlock()
|
||||||
portal, ok := bridge.portalsByMXID[mxid]
|
portal, ok := bridge.portalsByMXID[mxid]
|
||||||
if !ok {
|
if !ok {
|
||||||
dbPortal := bridge.DB.Portal.GetByMXID(mxid)
|
return bridge.loadDBPortal(bridge.DB.Portal.GetByMXID(mxid), nil)
|
||||||
if dbPortal == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
portal = bridge.NewPortal(dbPortal)
|
|
||||||
bridge.portalsByJID[portal.Key] = portal
|
|
||||||
if len(portal.MXID) > 0 {
|
|
||||||
bridge.portalsByMXID[portal.MXID] = portal
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return portal
|
return portal
|
||||||
}
|
}
|
||||||
|
@ -68,17 +60,7 @@ func (bridge *Bridge) GetPortalByJID(key database.PortalKey) *Portal {
|
||||||
defer bridge.portalsLock.Unlock()
|
defer bridge.portalsLock.Unlock()
|
||||||
portal, ok := bridge.portalsByJID[key]
|
portal, ok := bridge.portalsByJID[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
dbPortal := bridge.DB.Portal.GetByJID(key)
|
return bridge.loadDBPortal(bridge.DB.Portal.GetByJID(key), &key)
|
||||||
if dbPortal == nil {
|
|
||||||
dbPortal = bridge.DB.Portal.New()
|
|
||||||
dbPortal.Key = key
|
|
||||||
dbPortal.Insert()
|
|
||||||
}
|
|
||||||
portal = bridge.NewPortal(dbPortal)
|
|
||||||
bridge.portalsByJID[portal.Key] = portal
|
|
||||||
if len(portal.MXID) > 0 {
|
|
||||||
bridge.portalsByMXID[portal.MXID] = portal
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return portal
|
return portal
|
||||||
}
|
}
|
||||||
|
@ -91,17 +73,34 @@ func (bridge *Bridge) GetAllPortals() []*Portal {
|
||||||
for index, dbPortal := range dbPortals {
|
for index, dbPortal := range dbPortals {
|
||||||
portal, ok := bridge.portalsByJID[dbPortal.Key]
|
portal, ok := bridge.portalsByJID[dbPortal.Key]
|
||||||
if !ok {
|
if !ok {
|
||||||
portal = bridge.NewPortal(dbPortal)
|
portal = bridge.loadDBPortal(dbPortal, nil)
|
||||||
bridge.portalsByJID[portal.Key] = portal
|
|
||||||
if len(dbPortal.MXID) > 0 {
|
|
||||||
bridge.portalsByMXID[dbPortal.MXID] = portal
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
output[index] = portal
|
output[index] = portal
|
||||||
}
|
}
|
||||||
return output
|
return output
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (bridge *Bridge) loadDBPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal {
|
||||||
|
if dbPortal == nil {
|
||||||
|
if key == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
dbPortal = bridge.DB.Portal.New()
|
||||||
|
dbPortal.Key = *key
|
||||||
|
dbPortal.Insert()
|
||||||
|
}
|
||||||
|
portal := bridge.NewPortal(dbPortal)
|
||||||
|
bridge.portalsByJID[portal.Key] = portal
|
||||||
|
if len(portal.MXID) > 0 {
|
||||||
|
bridge.portalsByMXID[portal.MXID] = portal
|
||||||
|
}
|
||||||
|
return portal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (portal *Portal) GetUsers() []*User {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
|
func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
|
||||||
portal := &Portal{
|
portal := &Portal{
|
||||||
Portal: dbPortal,
|
Portal: dbPortal,
|
||||||
|
|
86
user.go
86
user.go
|
@ -66,20 +66,7 @@ func (bridge *Bridge) GetUserByMXID(userID types.MatrixUserID) *User {
|
||||||
defer bridge.usersLock.Unlock()
|
defer bridge.usersLock.Unlock()
|
||||||
user, ok := bridge.usersByMXID[userID]
|
user, ok := bridge.usersByMXID[userID]
|
||||||
if !ok {
|
if !ok {
|
||||||
dbUser := bridge.DB.User.GetByMXID(userID)
|
return bridge.loadDBUser(bridge.DB.User.GetByMXID(userID), &userID)
|
||||||
if dbUser == nil {
|
|
||||||
dbUser = bridge.DB.User.New()
|
|
||||||
dbUser.MXID = userID
|
|
||||||
dbUser.Insert()
|
|
||||||
}
|
|
||||||
user = bridge.NewUser(dbUser)
|
|
||||||
bridge.usersByMXID[user.MXID] = user
|
|
||||||
if len(user.JID) > 0 {
|
|
||||||
bridge.usersByJID[user.JID] = user
|
|
||||||
}
|
|
||||||
if len(user.ManagementRoom) > 0 {
|
|
||||||
bridge.managementRooms[user.ManagementRoom] = user
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return user
|
return user
|
||||||
}
|
}
|
||||||
|
@ -89,16 +76,7 @@ func (bridge *Bridge) GetUserByJID(userID types.WhatsAppID) *User {
|
||||||
defer bridge.usersLock.Unlock()
|
defer bridge.usersLock.Unlock()
|
||||||
user, ok := bridge.usersByJID[userID]
|
user, ok := bridge.usersByJID[userID]
|
||||||
if !ok {
|
if !ok {
|
||||||
dbUser := bridge.DB.User.GetByJID(userID)
|
return bridge.loadDBUser(bridge.DB.User.GetByJID(userID), nil)
|
||||||
if dbUser == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
user = bridge.NewUser(dbUser)
|
|
||||||
bridge.usersByMXID[user.MXID] = user
|
|
||||||
bridge.usersByJID[user.JID] = user
|
|
||||||
if len(user.ManagementRoom) > 0 {
|
|
||||||
bridge.managementRooms[user.ManagementRoom] = user
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return user
|
return user
|
||||||
}
|
}
|
||||||
|
@ -111,20 +89,50 @@ func (bridge *Bridge) GetAllUsers() []*User {
|
||||||
for index, dbUser := range dbUsers {
|
for index, dbUser := range dbUsers {
|
||||||
user, ok := bridge.usersByMXID[dbUser.MXID]
|
user, ok := bridge.usersByMXID[dbUser.MXID]
|
||||||
if !ok {
|
if !ok {
|
||||||
user = bridge.NewUser(dbUser)
|
user = bridge.loadDBUser(dbUser, nil)
|
||||||
bridge.usersByMXID[user.MXID] = user
|
|
||||||
if len(user.JID) > 0 {
|
|
||||||
bridge.usersByJID[user.JID] = user
|
|
||||||
}
|
|
||||||
if len(user.ManagementRoom) > 0 {
|
|
||||||
bridge.managementRooms[user.ManagementRoom] = user
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
output[index] = user
|
output[index] = user
|
||||||
}
|
}
|
||||||
return output
|
return output
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *types.MatrixUserID) *User {
|
||||||
|
if dbUser == nil {
|
||||||
|
if mxid == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
dbUser = bridge.DB.User.New()
|
||||||
|
dbUser.MXID = *mxid
|
||||||
|
dbUser.Insert()
|
||||||
|
}
|
||||||
|
user := bridge.NewUser(dbUser)
|
||||||
|
bridge.usersByMXID[user.MXID] = user
|
||||||
|
if len(user.JID) > 0 {
|
||||||
|
bridge.usersByJID[user.JID] = user
|
||||||
|
}
|
||||||
|
if len(user.ManagementRoom) > 0 {
|
||||||
|
bridge.managementRooms[user.ManagementRoom] = user
|
||||||
|
}
|
||||||
|
return user
|
||||||
|
}
|
||||||
|
|
||||||
|
func (user *User) GetPortals() []*Portal {
|
||||||
|
keys := user.User.GetPortalKeys()
|
||||||
|
portals := make([]*Portal, len(keys))
|
||||||
|
|
||||||
|
user.bridge.portalsLock.Lock()
|
||||||
|
defer user.bridge.portalsLock.Unlock()
|
||||||
|
|
||||||
|
for i, key := range keys {
|
||||||
|
portal, ok := user.bridge.portalsByJID[key]
|
||||||
|
if !ok {
|
||||||
|
portal = user.bridge.loadDBPortal(user.bridge.DB.Portal.GetByJID(key), &key)
|
||||||
|
}
|
||||||
|
portals[i] = portal
|
||||||
|
}
|
||||||
|
return portals
|
||||||
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) NewUser(dbUser *database.User) *User {
|
func (bridge *Bridge) NewUser(dbUser *database.User) *User {
|
||||||
user := &User{
|
user := &User{
|
||||||
User: dbUser,
|
User: dbUser,
|
||||||
|
@ -295,18 +303,26 @@ func (user *User) PostLogin() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) syncPortals(createAll bool) {
|
func (user *User) syncPortals(createAll bool) {
|
||||||
var chats ChatList
|
chats := make(ChatList, 0, len(user.Conn.Store.Chats))
|
||||||
|
portalKeys := make([]database.PortalKey, 0, len(user.Conn.Store.Chats))
|
||||||
for _, chat := range user.Conn.Store.Chats {
|
for _, chat := range user.Conn.Store.Chats {
|
||||||
ts, err := strconv.ParseUint(chat.LastMessageTime, 10, 64)
|
ts, err := strconv.ParseUint(chat.LastMessageTime, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
user.log.Warnfln("Non-integer last message time in %s: %s", chat.Jid, chat.LastMessageTime)
|
user.log.Warnfln("Non-integer last message time in %s: %s", chat.Jid, chat.LastMessageTime)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
portal := user.GetPortalByJID(chat.Jid)
|
||||||
|
|
||||||
chats = append(chats, Chat{
|
chats = append(chats, Chat{
|
||||||
Portal: user.GetPortalByJID(chat.Jid),
|
Portal: portal,
|
||||||
Contact: user.Conn.Store.Contacts[chat.Jid],
|
Contact: user.Conn.Store.Contacts[chat.Jid],
|
||||||
LastMessageTime: ts,
|
LastMessageTime: ts,
|
||||||
})
|
})
|
||||||
|
portalKeys = append(portalKeys, portal.Key)
|
||||||
|
}
|
||||||
|
err := user.SetPortalKeys(portalKeys)
|
||||||
|
if err != nil {
|
||||||
|
user.log.Warnln("Failed to update user-portal mapping:", err)
|
||||||
}
|
}
|
||||||
sort.Sort(chats)
|
sort.Sort(chats)
|
||||||
limit := user.bridge.Config.Bridge.InitialChatSync
|
limit := user.bridge.Config.Bridge.InitialChatSync
|
||||||
|
@ -315,7 +331,7 @@ func (user *User) syncPortals(createAll bool) {
|
||||||
}
|
}
|
||||||
now := uint64(time.Now().Unix())
|
now := uint64(time.Now().Unix())
|
||||||
for i, chat := range chats {
|
for i, chat := range chats {
|
||||||
if chat.LastMessageTime + user.bridge.Config.Bridge.SyncChatMaxAge < now {
|
if chat.LastMessageTime+user.bridge.Config.Bridge.SyncChatMaxAge < now {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
create := (chat.LastMessageTime >= user.LastConnection && user.LastConnection > 0) || i < limit
|
create := (chat.LastMessageTime >= user.LastConnection && user.LastConnection > 0) || i < limit
|
||||||
|
|
Loading…
Reference in a new issue