diff --git a/config/bridge.go b/config/bridge.go index 01c6112..5468d19 100644 --- a/config/bridge.go +++ b/config/bridge.go @@ -47,6 +47,9 @@ type BridgeConfig struct { SyncWithCustomPuppets bool `yaml:"sync_with_custom_puppets"` + InviteOwnPuppetForBackfilling bool `yaml:"invite_own_puppet_for_backfilling"` + PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` + CommandPrefix string `yaml:"command_prefix"` Permissions PermissionConfig `yaml:"permissions"` @@ -69,6 +72,9 @@ func (bc *BridgeConfig) setDefaults() { bc.SyncChatMaxAge = 259200 bc.SyncWithCustomPuppets = true + + bc.InviteOwnPuppetForBackfilling = true + bc.PrivateChatPortalMeta = false } type umBridgeConfig BridgeConfig diff --git a/database/portal.go b/database/portal.go index 9772387..77fb162 100644 --- a/database/portal.go +++ b/database/portal.go @@ -66,16 +66,8 @@ func (pq *PortalQuery) New() *Portal { } } -func (pq *PortalQuery) GetAll() (portals []*Portal) { - rows, err := pq.db.Query("SELECT * FROM portal") - if err != nil || rows == nil { - return nil - } - defer rows.Close() - for rows.Next() { - portals = append(portals, pq.New().Scan(rows)) - } - return +func (pq *PortalQuery) GetAll() []*Portal { + return pq.getAll("SELECT * FROM portal") } func (pq *PortalQuery) GetByJID(key PortalKey) *Portal { @@ -86,6 +78,22 @@ func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal { return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid) } +func (pq *PortalQuery) GetAllByJID(jid types.WhatsAppID) []*Portal { + return pq.getAll("SELECT * FROM portal WHERE jid=$1", jid) +} + +func (pq *PortalQuery) getAll(query string, args ...interface{}) (portals []*Portal) { + rows, err := pq.db.Query(query, args...) + if err != nil || rows == nil { + return nil + } + defer rows.Close() + for rows.Next() { + portals = append(portals, pq.New().Scan(rows)) + } + return +} + func (pq *PortalQuery) get(query string, args ...interface{}) *Portal { row := pq.db.QueryRow(query, args...) if row == nil { @@ -101,14 +109,15 @@ type Portal struct { Key PortalKey MXID types.MatrixRoomID - Name string - Topic string - Avatar string + Name string + Topic string + Avatar string + AvatarURL string } func (portal *Portal) Scan(row Scannable) *Portal { - var mxid sql.NullString - err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar) + var mxid, avatarURL sql.NullString + err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar, &avatarURL) if err != nil { if err != sql.ErrNoRows { portal.log.Errorln("Database scan failed:", err) @@ -116,6 +125,7 @@ func (portal *Portal) Scan(row Scannable) *Portal { return nil } portal.MXID = mxid.String + portal.AvatarURL = avatarURL.String return portal } @@ -127,8 +137,8 @@ func (portal *Portal) mxidPtr() *string { } func (portal *Portal) Insert() { - _, err := portal.db.Exec("INSERT INTO portal VALUES ($1, $2, $3, $4, $5, $6)", - portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar) + _, err := portal.db.Exec("INSERT INTO portal VALUES ($1, $2, $3, $4, $5, $6, $7)", + portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL) if err != nil { portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err) } @@ -139,8 +149,8 @@ func (portal *Portal) Update() { if len(portal.MXID) > 0 { mxid = &portal.MXID } - _, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4 WHERE jid=$5 AND receiver=$6", - mxid, portal.Name, portal.Topic, portal.Avatar, portal.Key.JID, portal.Key.Receiver) + _, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5 WHERE jid=$6 AND receiver=$7", + mxid, portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL, portal.Key.JID, portal.Key.Receiver) if err != nil { portal.log.Warnfln("Failed to update %s: %v", portal.Key, err) } diff --git a/database/puppet.go b/database/puppet.go index 66fc5ee..8a9cfae 100644 --- a/database/puppet.go +++ b/database/puppet.go @@ -37,7 +37,7 @@ func (pq *PuppetQuery) New() *Puppet { } func (pq *PuppetQuery) GetAll() (puppets []*Puppet) { - rows, err := pq.db.Query("SELECT * FROM puppet") + rows, err := pq.db.Query("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch FROM puppet") if err != nil || rows == nil { return nil } @@ -49,7 +49,7 @@ func (pq *PuppetQuery) GetAll() (puppets []*Puppet) { } func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet { - row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=$1", jid) + row := pq.db.QueryRow("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch FROM puppet WHERE jid=$1", jid) if row == nil { return nil } @@ -57,7 +57,7 @@ func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet { } func (pq *PuppetQuery) GetByCustomMXID(mxid types.MatrixUserID) *Puppet { - row := pq.db.QueryRow("SELECT * FROM puppet WHERE custom_mxid=$1", mxid) + row := pq.db.QueryRow("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch FROM puppet WHERE custom_mxid=$1", mxid) if row == nil { return nil } @@ -65,7 +65,7 @@ func (pq *PuppetQuery) GetByCustomMXID(mxid types.MatrixUserID) *Puppet { } func (pq *PuppetQuery) GetAllWithCustomMXID() (puppets []*Puppet) { - rows, err := pq.db.Query("SELECT * FROM puppet WHERE custom_mxid<>''") + rows, err := pq.db.Query("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch FROM puppet WHERE custom_mxid<>''") if err != nil || rows == nil { return nil } @@ -82,6 +82,7 @@ type Puppet struct { JID types.WhatsAppID Avatar string + AvatarURL string Displayname string NameQuality int8 @@ -91,9 +92,9 @@ type Puppet struct { } func (puppet *Puppet) Scan(row Scannable) *Puppet { - var displayname, avatar, customMXID, accessToken, nextBatch sql.NullString + var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString var quality sql.NullInt64 - err := row.Scan(&puppet.JID, &avatar, &displayname, &quality, &customMXID, &accessToken, &nextBatch) + err := row.Scan(&puppet.JID, &avatar, &avatarURL, &displayname, &quality, &customMXID, &accessToken, &nextBatch) if err != nil { if err != sql.ErrNoRows { puppet.log.Errorln("Database scan failed:", err) @@ -102,6 +103,7 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet { } puppet.Displayname = displayname.String puppet.Avatar = avatar.String + puppet.AvatarURL = avatarURL.String puppet.NameQuality = int8(quality.Int64) puppet.CustomMXID = customMXID.String puppet.AccessToken = accessToken.String @@ -110,16 +112,16 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet { } func (puppet *Puppet) Insert() { - _, err := puppet.db.Exec("INSERT INTO puppet VALUES ($1, $2, $3, $4, $5, $6, $7)", - puppet.JID, puppet.Avatar, puppet.Displayname, puppet.NameQuality, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch) + _, err := puppet.db.Exec("INSERT INTO puppet VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", + puppet.JID, puppet.Avatar, puppet.AvatarURL, puppet.Displayname, puppet.NameQuality, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch) if err != nil { puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err) } } func (puppet *Puppet) Update() { - _, err := puppet.db.Exec("UPDATE puppet SET displayname=$1, name_quality=$2, avatar=$3, custom_mxid=$4, access_token=$5, next_batch=$6 WHERE jid=$7", - puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.JID) + _, err := puppet.db.Exec("UPDATE puppet SET displayname=$1, name_quality=$2, avatar=$3, avatar_url=$4, custom_mxid=$5, access_token=$6, next_batch=$7 WHERE jid=$8", + puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.AvatarURL, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.JID) if err != nil { puppet.log.Warnfln("Failed to update %s->%s: %v", puppet.JID, err) } diff --git a/database/upgrades/2019-06-01-avatar-url-fields.go b/database/upgrades/2019-06-01-avatar-url-fields.go new file mode 100644 index 0000000..ae6cf8b --- /dev/null +++ b/database/upgrades/2019-06-01-avatar-url-fields.go @@ -0,0 +1,19 @@ +package upgrades + +import ( + "database/sql" +) + +func init() { + upgrades[7] = upgrade{"Add columns to store avatar MXC URIs", func(dialect Dialect, tx *sql.Tx, db *sql.DB) error { + _, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN avatar_url VARCHAR(255)`) + if err != nil { + return err + } + _, err = tx.Exec(`ALTER TABLE portal ADD COLUMN avatar_url VARCHAR(255)`) + if err != nil { + return err + } + return nil + }} +} diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go index 6f2faa1..54171b1 100644 --- a/database/upgrades/upgrades.go +++ b/database/upgrades/upgrades.go @@ -22,7 +22,7 @@ type upgrade struct { fn upgradeFunc } -const NumberOfUpgrades = 7 +const NumberOfUpgrades = 8 var upgrades [NumberOfUpgrades]upgrade diff --git a/example-config.yaml b/example-config.yaml index 6f01b15..5facea5 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -91,6 +91,16 @@ bridge: # are not normally sent to appservices. sync_with_custom_puppets: true + # Whether or not to invite own WhatsApp user's Matrix puppet into private + # chat portals when backfilling if needed. + # This always uses the default puppet instead of custom puppets due to + # rate limits and timestamp massaging. + invite_own_puppet_for_backfilling: true + # Whether or not to explicitly set the avatar and room name for private + # chat portal rooms. This can be useful if the previous field works fine, + # but causes room avatar/name bugs. + private_chat_portal_meta: false + # The prefix for commands. Only required in non-management rooms. command_prefix: "!wa" diff --git a/portal.go b/portal.go index 4195b22..26732a0 100644 --- a/portal.go +++ b/portal.go @@ -67,9 +67,16 @@ func (bridge *Bridge) GetPortalByJID(key database.PortalKey) *Portal { } func (bridge *Bridge) GetAllPortals() []*Portal { + return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAll()) +} + +func (bridge *Bridge) GetAllPortalsByJID(jid types.WhatsAppID) []*Portal { + return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAllByJID(jid)) +} + +func (bridge *Bridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal { bridge.portalsLock.Lock() defer bridge.portalsLock.Unlock() - dbPortals := bridge.DB.Portal.GetAll() output := make([]*Portal, len(dbPortals)) for index, dbPortal := range dbPortals { portal, ok := bridge.portalsByJID[dbPortal.Key] @@ -131,8 +138,6 @@ type Portal struct { bridge *Bridge log log.Logger - avatarURL string - roomCreateLock sync.Mutex recentlyHandled [recentlyHandledLength]types.WhatsAppMessageID @@ -333,7 +338,7 @@ func (portal *Portal) UpdateAvatar(user *User, avatar *whatsappExt.ProfilePicInf return false } - portal.avatarURL = resp.ContentURI + portal.AvatarURL = resp.ContentURI if len(portal.MXID) > 0 { _, err = portal.MainIntent().SetRoomAvatar(portal.MXID, resp.ContentURI) if err != nil { @@ -582,7 +587,7 @@ func (portal *Portal) beginBackfill() func() { portal.backfilling = true var privateChatPuppetInvited bool var privateChatPuppet *Puppet - if portal.IsPrivateChat() { + if portal.IsPrivateChat() && portal.bridge.Config.Bridge.InviteOwnPuppetForBackfilling { privateChatPuppet = portal.bridge.GetPuppetByJID(portal.Key.Receiver) portal.privateChatBackfillInvitePuppet = func() { if privateChatPuppetInvited { @@ -686,7 +691,14 @@ func (portal *Portal) CreateMatrixRoom(user *User) error { var metadata *whatsappExt.GroupInfo isPrivateChat := false if portal.IsPrivateChat() { - portal.Name = "" + puppet := portal.bridge.GetPuppetByJID(portal.Key.JID) + if portal.bridge.Config.Bridge.PrivateChatPortalMeta { + portal.Name = puppet.Displayname + portal.AvatarURL = puppet.AvatarURL + portal.Avatar = puppet.Avatar + } else { + portal.Name = "" + } portal.Topic = "WhatsApp private chat" isPrivateChat = true } else if portal.IsStatusBroadcastRoom() { @@ -708,11 +720,11 @@ func (portal *Portal) CreateMatrixRoom(user *User) error { PowerLevels: portal.GetBasePowerLevels(), }, }} - if len(portal.avatarURL) > 0 { + if len(portal.AvatarURL) > 0 { initialState = append(initialState, &mautrix.Event{ Type: mautrix.StateRoomAvatar, Content: mautrix.Content{ - URL: portal.avatarURL, + URL: portal.AvatarURL, }, }) } diff --git a/puppet.go b/puppet.go index 0f41dc9..bc96eba 100644 --- a/puppet.go +++ b/puppet.go @@ -193,7 +193,9 @@ func (puppet *Puppet) UpdateAvatar(source *User, avatar *whatsappExt.ProfilePicI if err != nil { puppet.log.Warnln("Failed to remove avatar:", err) } + puppet.AvatarURL = "" puppet.Avatar = avatar.Tag + go puppet.updatePortalAvatar() return true } @@ -210,14 +212,68 @@ func (puppet *Puppet) UpdateAvatar(source *User, avatar *whatsappExt.ProfilePicI return false } - err = puppet.DefaultIntent().SetAvatarURL(resp.ContentURI) + puppet.AvatarURL = resp.ContentURI + err = puppet.DefaultIntent().SetAvatarURL(puppet.AvatarURL) if err != nil { puppet.log.Warnln("Failed to set avatar:", err) } puppet.Avatar = avatar.Tag + go puppet.updatePortalAvatar() return true } +func (puppet *Puppet) UpdateName(source *User, contact whatsapp.Contact) bool { + newName, quality := puppet.bridge.Config.Bridge.FormatDisplayname(contact) + if puppet.Displayname != newName && quality >= puppet.NameQuality { + err := puppet.DefaultIntent().SetDisplayName(newName) + if err == nil { + puppet.Displayname = newName + puppet.NameQuality = quality + go puppet.updatePortalName() + puppet.Update() + } else { + puppet.log.Warnln("Failed to set display name:", err) + } + return true + } + return false +} + +func (puppet *Puppet) updatePortalMeta(meta func(portal *Portal)) { + if puppet.bridge.Config.Bridge.PrivateChatPortalMeta { + for _, portal := range puppet.bridge.GetAllPortalsByJID(puppet.JID) { + meta(portal) + } + } +} + +func (puppet *Puppet) updatePortalAvatar() { + puppet.updatePortalMeta(func(portal *Portal) { + if len(portal.MXID) > 0 { + _, err := portal.MainIntent().SetRoomAvatar(portal.MXID, puppet.AvatarURL) + if err != nil { + portal.log.Warnln("Failed to set avatar:", err) + } + } + portal.AvatarURL = puppet.AvatarURL + portal.Avatar = puppet.Avatar + portal.Update() + }) +} + +func (puppet *Puppet) updatePortalName() { + puppet.updatePortalMeta(func(portal *Portal) { + if len(portal.MXID) > 0 { + _, err := portal.MainIntent().SetRoomName(portal.MXID, puppet.Displayname) + if err != nil { + portal.log.Warnln("Failed to set name:", err) + } + } + portal.Name = puppet.Displayname + portal.Update() + }) +} + func (puppet *Puppet) Sync(source *User, contact whatsapp.Contact) { err := puppet.DefaultIntent().EnsureRegistered() if err != nil { @@ -227,19 +283,11 @@ func (puppet *Puppet) Sync(source *User, contact whatsapp.Contact) { if contact.Jid == source.JID { contact.Notify = source.Conn.Info.Pushname } - newName, quality := puppet.bridge.Config.Bridge.FormatDisplayname(contact) - if puppet.Displayname != newName && quality >= puppet.NameQuality { - err := puppet.DefaultIntent().SetDisplayName(newName) - if err == nil { - puppet.Displayname = newName - puppet.NameQuality = quality - puppet.Update() - } else { - puppet.log.Warnln("Failed to set display name:", err) - } - } - if puppet.UpdateAvatar(source, nil) { + update := false + update = puppet.UpdateName(source, contact) || update + update = puppet.UpdateAvatar(source, nil) || update + if update { puppet.Update() } }