Possibly significantly improve how portals are created and synced

This commit is contained in:
Tulir Asokan 2019-05-22 16:46:18 +03:00
parent 6f2a51410f
commit b363547bdf
10 changed files with 248 additions and 33 deletions

View file

@ -208,6 +208,7 @@ func (handler *CommandHandler) CommandReconnect(ce *CommandEvent) {
ce.User.Connected = true
ce.User.ConnectionErrors = 0
ce.Reply("Reconnected successfully.")
go ce.User.PostLogin()
}
const cmdDisconnectHelp = `disconnect - Disconnect from WhatsApp (without logging out)`

View file

@ -37,6 +37,11 @@ type BridgeConfig struct {
MaxConnectionAttempts int `yaml:"max_connection_attempts"`
ReportConnectionRetry bool `yaml:"report_connection_retry"`
InitialChatSync int `yaml:"initial_chat_sync_count"`
InitialHistoryFill int `yaml:"initial_history_fill_count"`
RecoverChatSync int `yaml:"recovery_chat_sync_count"`
RecoverHistory bool `yaml:"recovery_history_backfill"`
CommandPrefix string `yaml:"command_prefix"`
Permissions PermissionConfig `yaml:"permissions"`
@ -49,6 +54,11 @@ func (bc *BridgeConfig) setDefaults() {
bc.ConnectionTimeout = 20
bc.MaxConnectionAttempts = 3
bc.ReportConnectionRetry = true
bc.InitialChatSync = 10
bc.InitialHistoryFill = 20
bc.RecoverChatSync = -1
bc.RecoverHistory = true
}
type umBridgeConfig BridgeConfig

View file

@ -53,11 +53,23 @@ func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
}
func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.WhatsAppMessageID) *Message {
return mq.get("SELECT * FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid)
return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content " +
"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid)
}
func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message {
return mq.get("SELECT * FROM message WHERE mxid=$1", mxid)
return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content " +
"FROM message WHERE mxid=$1", mxid)
}
func (mq *MessageQuery) GetLastInChat(chat PortalKey) *Message {
msg := mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content " +
"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 ORDER BY timestamp DESC LIMIT 1", chat.JID, chat.Receiver)
if msg.Timestamp == 0 {
// Old db, we don't know what the last message is.
return nil
}
return msg
}
func (mq *MessageQuery) get(query string, args ...interface{}) *Message {
@ -72,16 +84,17 @@ type Message struct {
db *Database
log log.Logger
Chat PortalKey
JID types.WhatsAppMessageID
MXID types.MatrixEventID
Sender types.WhatsAppID
Content *waProto.Message
Chat PortalKey
JID types.WhatsAppMessageID
MXID types.MatrixEventID
Sender types.WhatsAppID
Timestamp uint64
Content *waProto.Message
}
func (msg *Message) Scan(row Scannable) *Message {
var content []byte
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &content)
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &msg.Timestamp, &content)
if err != nil {
if err != sql.ErrNoRows {
msg.log.Errorln("Database scan failed:", err)

View file

@ -0,0 +1,15 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[2] = upgrade{"Add timestamp column to messages", func(dialect Dialect, tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0")
if err != nil {
return err
}
return nil
}}
}

View file

@ -0,0 +1,15 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[3] = upgrade{"Add last_connection column to users", func(dialect Dialect, tx *sql.Tx) error {
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN last_connection BIGINT NOT NULL DEFAULT 0`)
if err != nil {
return err
}
return nil
}}
}

View file

@ -22,7 +22,7 @@ type upgrade struct {
fn upgradeFunc
}
var upgrades [2]upgrade
var upgrades [4]upgrade
func getVersion(dialect Dialect, db *sql.DB) (int, error) {
_, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
@ -65,7 +65,7 @@ func Run(log log.Logger, dialectName string, db *sql.DB) error {
log.Infofln("Database currently on v%d, latest: v%d", version, len(upgrades))
for i, upgrade := range upgrades[version:] {
log.Infofln("Upgrading database to v%d: %s", i+1, upgrade.message)
log.Infofln("Upgrading database to v%d: %s", version+i+1, upgrade.message)
tx, err := db.Begin()
if err != nil {
return err
@ -74,7 +74,7 @@ func Run(log log.Logger, dialectName string, db *sql.DB) error {
if err != nil {
return err
}
err = setVersion(dialect, tx, i+1)
err = setVersion(dialect, tx, version+i+1)
if err != nil {
return err
}

View file

@ -19,6 +19,7 @@ package database
import (
"database/sql"
"strings"
"time"
"github.com/Rhymen/go-whatsapp"
@ -76,12 +77,14 @@ type User struct {
JID types.WhatsAppID
ManagementRoom types.MatrixRoomID
Session *whatsapp.Session
LastConnection uint64
}
func (user *User) Scan(row Scannable) *User {
var jid, clientID, clientToken, serverToken sql.NullString
var encKey, macKey []byte
err := row.Scan(&user.MXID, &jid, &user.ManagementRoom, &clientID, &clientToken, &serverToken, &encKey, &macKey)
err := row.Scan(&user.MXID, &jid, &user.ManagementRoom, &clientID, &clientToken, &serverToken, &encKey, &macKey,
&user.LastConnection)
if err != nil {
if err != sql.ErrNoRows {
user.log.Errorln("Database scan failed:", err)
@ -134,18 +137,28 @@ func (user *User) sessionUnptr() (sess whatsapp.Session) {
func (user *User) Insert() {
sess := user.sessionUnptr()
_, err := user.db.Exec(`INSERT INTO "user" VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, user.MXID, user.jidPtr(),
user.ManagementRoom,
_, err := user.db.Exec(`INSERT INTO "user" (mxid, jid, management_room, last_connection, client_id, client_token, server_token, enc_key, mac_key) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
user.MXID, user.jidPtr(),
user.ManagementRoom, user.LastConnection,
sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey)
if err != nil {
user.log.Warnfln("Failed to insert %s: %v", user.MXID, err)
}
}
func (user *User) UpdateLastConnection() {
user.LastConnection = uint64(time.Now().Unix())
_, err := user.db.Exec(`UPDATE "user" SET last_connection=$1 WHERE mxid=$2`,
user.LastConnection, user.MXID)
if err != nil {
user.log.Warnfln("Failed to update last connection ts: %v", err)
}
}
func (user *User) Update() {
sess := user.sessionUnptr()
_, err := user.db.Exec(`UPDATE "user" SET jid=$1, management_room=$2, client_id=$3, client_token=$4, server_token=$5, enc_key=$6, mac_key=$7 WHERE mxid=$8`,
user.jidPtr(), user.ManagementRoom,
_, err := user.db.Exec(`UPDATE "user" SET jid=$1, management_room=$2, last_connection=$3, client_id=$4, client_token=$5, server_token=$6, enc_key=$7, mac_key=$8 WHERE mxid=$9`,
user.jidPtr(), user.ManagementRoom, user.LastConnection,
sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey,
user.MXID)
if err != nil {

View file

@ -66,6 +66,16 @@ bridge:
# If false, it will only report when it stops retrying.
report_connection_retry: true
# Number of chats to sync for new users.
initial_chat_sync_count: 10
# Number of old messages to fill when creating new portal rooms.
initial_history_fill_count: 20
# Maximum number of chats to sync when recovering from downtime.
# Set to -1 to sync all new chats during downtime.
recovery_chat_sync_limit: -1
# Whether or not to sync history when recovering from downtime.
recovery_history_backfill: true
# The prefix for commands. Only required in non-management rooms.
command_prefix: "!wa"

View file

@ -30,8 +30,10 @@ import (
"net/http"
"strings"
"sync"
"time"
"github.com/Rhymen/go-whatsapp"
"github.com/Rhymen/go-whatsapp/binary"
waProto "github.com/Rhymen/go-whatsapp/binary/proto"
log "maunium.net/go/maulogger/v2"
@ -119,8 +121,9 @@ func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
const recentlyHandledLength = 100
type PortalMessage struct {
source *User
data interface{}
source *User
data interface{}
timestamp uint64
}
type Portal struct {
@ -137,6 +140,7 @@ type Portal struct {
recentlyHandledLock sync.Mutex
recentlyHandledIndex uint8
backfillLock sync.Mutex
lastMessageTs uint64
messages chan PortalMessage
@ -144,11 +148,13 @@ type Portal struct {
isPrivate *bool
}
const MaxMessageAgeToCreatePortal = 5 * 60 // 5 minutes
func (portal *Portal) handleMessageLoop() {
for msg := range portal.messages {
if len(portal.MXID) == 0 {
_, isRevocation := msg.data.(whatsappExt.MessageRevocation)
if isRevocation {
if msg.timestamp+MaxMessageAgeToCreatePortal < uint64(time.Now().Unix()) {
portal.log.Debugln("Not creating portal room for incoming message as the message is too old.")
continue
}
err := portal.CreateMatrixRoom(msg.source)
@ -221,6 +227,7 @@ func (portal *Portal) markHandled(source *User, message *waProto.WebMessageInfo,
msg.Chat = portal.Key
msg.JID = message.GetKey().GetId()
msg.MXID = mxid
msg.Timestamp = message.GetMessageTimestamp()
if message.GetKey().GetFromMe() {
msg.Sender = source.JID
} else if portal.IsPrivateChat() {
@ -414,6 +421,7 @@ func (portal *Portal) Sync(user *User, contact whatsapp.Contact) {
if portal.IsPrivateChat() {
return
}
portal.log.Infoln("Syncing portal for", user.MXID)
if len(portal.MXID) == 0 {
portal.Name = contact.Name
@ -524,15 +532,52 @@ func (portal *Portal) RestrictMetadataChanges(restrict bool) {
}
}
func (portal *Portal) FillHistory(user *User) error {
resp, err := user.Conn.LoadMessages(portal.Key.JID, "", 50)
func (portal *Portal) BackfillHistory(user *User) error {
if !portal.bridge.Config.Bridge.RecoverHistory {
return nil
}
portal.backfillLock.Lock()
defer portal.backfillLock.Unlock()
lastMessage := portal.bridge.DB.Message.GetLastInChat(portal.Key)
if lastMessage == nil {
return nil
}
lastMessageID := lastMessage.JID
portal.log.Infoln("Backfilling history since", lastMessageID, "for", user.MXID)
for len(lastMessageID) > 0 {
portal.log.Debugln("Backfilling history: 50 messages after", lastMessageID)
resp, err := user.Conn.LoadMessagesAfter(portal.Key.JID, lastMessageID, 50)
if err != nil {
return err
}
lastMessageID, err = portal.handleHistory(user, resp)
if err != nil {
return err
}
}
portal.log.Infoln("Backfilling finished")
return nil
}
func (portal *Portal) FillInitialHistory(user *User) error {
if portal.bridge.Config.Bridge.InitialHistoryFill == 0 {
return nil
}
resp, err := user.Conn.LoadMessages(portal.Key.JID, "", portal.bridge.Config.Bridge.InitialHistoryFill)
if err != nil {
return err
}
messages, ok := resp.Content.([]interface{})
_, err = portal.handleHistory(user, resp)
return err
}
func (portal *Portal) handleHistory(user *User, history *binary.Node) (string, error) {
messages, ok := history.Content.([]interface{})
if !ok {
return fmt.Errorf("history response not list")
return "", fmt.Errorf("history response not a list")
}
lastID := ""
for _, rawMessage := range messages {
message, ok := rawMessage.(*waProto.WebMessageInfo)
if !ok {
@ -541,8 +586,9 @@ func (portal *Portal) FillHistory(user *User) error {
}
fmt.Println("Filling history", message.GetKey(), message.GetMessageTimestamp())
portal.handleMessage(PortalMessage{user, whatsapp.ParseProtoMessage(message)})
lastID = message.GetKey().GetId()
}
return nil
return lastID, nil
}
func (portal *Portal) CreateMatrixRoom(user *User) error {
@ -557,6 +603,8 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
return err
}
portal.log.Infoln("Creating Matrix room. Info source:", user.MXID)
isPrivateChat := false
if portal.IsPrivateChat() {
portal.Name = ""
@ -592,7 +640,7 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
}
portal.MXID = resp.RoomID
portal.Update()
err = portal.FillHistory(user)
err = portal.FillInitialHistory(user)
if err != nil {
portal.log.Errorln("Failed to fill history:", err)
}

104
user.go
View file

@ -19,6 +19,8 @@ package main
import (
"encoding/json"
"fmt"
"sort"
"strconv"
"strings"
"time"
@ -29,6 +31,7 @@ import (
"maunium.net/go/mautrix/format"
"github.com/Rhymen/go-whatsapp"
waProto "github.com/Rhymen/go-whatsapp/binary/proto"
"maunium.net/go/mautrix-whatsapp/database"
"maunium.net/go/mautrix-whatsapp/types"
@ -142,6 +145,9 @@ func (user *User) SetManagementRoom(roomID types.MatrixRoomID) {
func (user *User) SetSession(session *whatsapp.Session) {
user.Session = session
if session == nil {
user.LastConnection = 0
}
user.Update()
}
@ -188,6 +194,7 @@ func (user *User) RestoreSession() bool {
user.ConnectionErrors = 0
user.SetSession(&sess)
user.log.Debugln("Session restored successfully")
go user.PostLogin()
}
return true
}
@ -243,7 +250,84 @@ func (user *User) Login(ce *CommandEvent) {
user.ConnectionErrors = 0
user.JID = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1)
user.SetSession(&session)
ce.Reply("Successfully logged in. Now, you may ask for `sync [--create]`.")
ce.Reply("Successfully logged in, synchronizing chats...")
go user.PostLogin()
}
type Chat struct {
Portal *Portal
LastMessageTime uint64
Contact whatsapp.Contact
}
type ChatList []Chat
func (cl ChatList) Len() int {
return len(cl)
}
func (cl ChatList) Less(i, j int) bool {
return cl[i].LastMessageTime < cl[i].LastMessageTime
}
func (cl ChatList) Swap(i, j int) {
cl[i], cl[j] = cl[j], cl[i]
}
func (user *User) PostLogin() {
user.log.Debugln("Waiting for 3 seconds for contacts to arrive")
// Hacky way to wait for chats and contacts to arrive automatically
time.Sleep(3 * time.Second)
user.log.Debugln("Waited 3 seconds:", len(user.Conn.Store.Chats), len(user.Conn.Store.Contacts))
go user.syncPortals()
go user.syncPuppets()
}
func (user *User) syncPortals() {
var chats ChatList
for _, chat := range user.Conn.Store.Chats {
ts, err := strconv.ParseUint(chat.LastMessageTime, 10, 64)
if err != nil {
user.log.Warnfln("Non-integer last message time in %s: %s", chat.Jid, chat.LastMessageTime)
continue
}
chats = append(chats, Chat{
Portal: user.GetPortalByJID(chat.Jid),
Contact: user.Conn.Store.Contacts[chat.Jid],
LastMessageTime: ts,
})
}
sort.Sort(chats)
limit := user.bridge.Config.Bridge.InitialChatSync
if limit < 0 {
limit = len(chats)
}
for i, chat := range chats {
create := (chat.LastMessageTime >= user.LastConnection && user.LastConnection > 0) || i < limit
if len(chat.Portal.MXID) > 0 || create {
chat.Portal.Sync(user, chat.Contact)
err := chat.Portal.BackfillHistory(user)
if err != nil {
chat.Portal.log.Errorln("Error backfilling history:", err)
}
}
}
}
func (user *User) syncPuppets() {
for jid, contact := range user.Conn.Store.Contacts {
if strings.HasSuffix(jid, whatsappExt.NewUserSuffix) {
puppet := user.bridge.GetPuppetByJID(contact.Jid)
puppet.Sync(user, contact)
}
}
}
func (user *User) updateLastConnectionIfNecessary() {
if user.LastConnection+60 < uint64(time.Now().Unix()) {
user.UpdateLastConnection()
}
}
func (user *User) HandleError(err error) {
@ -282,6 +366,7 @@ func (user *User) HandleError(err error) {
user.ConnectionErrors = 0
user.Connected = true
_, _ = user.bridge.Bot.SendNotice(user.ManagementRoom, "Reconnected successfully")
go user.PostLogin()
return
}
user.log.Errorln("Error while trying to reconnect after disconnection:", err)
@ -324,27 +409,27 @@ func (user *User) GetPortalByJID(jid types.WhatsAppID) *Portal {
}
func (user *User) HandleTextMessage(message whatsapp.TextMessage) {
user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message}
user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message, message.Info.Timestamp}
}
func (user *User) HandleImageMessage(message whatsapp.ImageMessage) {
user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message}
user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message, message.Info.Timestamp}
}
func (user *User) HandleVideoMessage(message whatsapp.VideoMessage) {
user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message}
user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message, message.Info.Timestamp}
}
func (user *User) HandleAudioMessage(message whatsapp.AudioMessage) {
user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message}
user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message, message.Info.Timestamp}
}
func (user *User) HandleDocumentMessage(message whatsapp.DocumentMessage) {
user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message}
user.GetPortalByJID(message.Info.RemoteJid).messages <- PortalMessage{user, message, message.Info.Timestamp}
}
func (user *User) HandleMessageRevoke(message whatsappExt.MessageRevocation) {
user.GetPortalByJID(message.RemoteJid).messages <- PortalMessage{user, message}
user.GetPortalByJID(message.RemoteJid).messages <- PortalMessage{user, message, 0}
}
func (user *User) HandlePresence(info whatsappExt.Presence) {
@ -457,4 +542,9 @@ func (user *User) HandleJsonMessage(message string) {
return
}
user.log.Debugln("JSON message:", message)
user.updateLastConnectionIfNecessary()
}
func (user *User) HandleRawMessage(message *waProto.WebMessageInfo) {
user.updateLastConnectionIfNecessary()
}