From c7348f29b03ca6c2b1bcf1dfafcb9d641f3f06a5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 29 Aug 2018 00:40:54 +0300 Subject: [PATCH] Initial desegregation of users and automatic config updating --- .gitignore | 1 + Gopkg.lock | 4 +- commands.go | 4 +- config/bridge.go | 16 +- config/config.go | 1 - config/recursivemap.go | 115 ++++++++++ config/registration.go | 2 +- config/update.go | 100 +++++++++ database/message.go | 35 +-- database/portal.go | 74 +++++-- database/puppet.go | 35 ++- database/user.go | 45 ++-- example-config.yaml | 19 +- formatting.go | 144 +++++++----- main.go | 77 +++++-- matrix.go | 26 ++- portal.go | 205 ++++++++---------- puppet.go | 92 +++----- user.go | 134 ++++++------ vendor/maunium.net/go/gomatrix/client.go | 18 +- vendor/maunium.net/go/gomatrix/events.go | 62 ++++-- .../go/mautrix-appservice/appservice.go | 27 ++- .../go/mautrix-appservice/intent.go | 6 +- .../go/mautrix-appservice/statestore.go | 39 ++-- 24 files changed, 806 insertions(+), 475 deletions(-) create mode 100644 config/recursivemap.go create mode 100644 config/update.go diff --git a/.gitignore b/.gitignore index 6a0c71d..8c8540e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ *.session *.json *.db +*.log diff --git a/Gopkg.lock b/Gopkg.lock index c23225f..41db0d9 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -123,7 +123,7 @@ ".", "format" ] - revision = "ead1f970c8f56d1854cb9eb4a54c03aa6dafd753" + revision = "42a3133c4980e4b1ea5fb52329d977f592d67cf0" [[projects]] branch = "master" @@ -141,7 +141,7 @@ branch = "master" name = "maunium.net/go/mautrix-appservice" packages = ["."] - revision = "269f2ab602126a2de94bc86a457392426cce1ab2" + revision = "37d4449056015cea5d0a4420bba595c61ad32007" [solve-meta] analyzer-name = "dep" diff --git a/commands.go b/commands.go index fbbd5bd..34255ef 100644 --- a/commands.go +++ b/commands.go @@ -48,7 +48,7 @@ type CommandEvent struct { func (ce *CommandEvent) Reply(msg string) { _, err := ce.Bot.SendNotice(string(ce.RoomID), msg) if err != nil { - ce.Handler.log.Warnfln("Failed to reply to command from %s: %v", ce.User.ID, err) + ce.Handler.log.Warnfln("Failed to reply to command from %s: %v", ce.User.MXID, err) } } @@ -56,7 +56,7 @@ func (handler *CommandHandler) Handle(roomID types.MatrixRoomID, user *User, mes args := strings.Split(message, " ") cmd := strings.ToLower(args[0]) ce := &CommandEvent{ - Bot: handler.bridge.AppService.BotIntent(), + Bot: handler.bridge.AS.BotIntent(), Bridge: handler.bridge, Handler: handler, RoomID: roomID, diff --git a/config/bridge.go b/config/bridge.go index 11c55ba..47a6860 100644 --- a/config/bridge.go +++ b/config/bridge.go @@ -56,12 +56,7 @@ func (bc *BridgeConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { return err } -type DisplaynameTemplateArgs struct { - Displayname string -} - type UsernameTemplateArgs struct { - Receiver string UserID string } @@ -74,14 +69,9 @@ func (bc BridgeConfig) FormatDisplayname(contact whatsapp.Contact) string { return buf.String() } -func (bc BridgeConfig) FormatUsername(receiver types.MatrixUserID, userID types.WhatsAppID) string { +func (bc BridgeConfig) FormatUsername(userID types.WhatsAppID) string { var buf bytes.Buffer - receiver = strings.Replace(receiver, "@", "=40", 1) - receiver = strings.Replace(receiver, ":", "=3", 1) - bc.usernameTemplate.Execute(&buf, UsernameTemplateArgs{ - Receiver: receiver, - UserID: userID, - }) + bc.usernameTemplate.Execute(&buf, userID) return buf.String() } @@ -92,7 +82,7 @@ func (bc BridgeConfig) MarshalYAML() (interface{}, error) { Name: "{{.Name}}", Short: "{{.Short}}", }) - bc.UsernameTemplate = bc.FormatUsername("{{.Receiver}}", "{{.UserID}}") + bc.UsernameTemplate = bc.FormatUsername("{{.}}") return bc, nil } diff --git a/config/config.go b/config/config.go index 53680c9..230ecf1 100644 --- a/config/config.go +++ b/config/config.go @@ -78,7 +78,6 @@ func (config *Config) Save(path string) error { func (config *Config) MakeAppService() (*appservice.AppService, error) { as := appservice.Create() - as.LogConfig = config.Logging as.HomeserverDomain = config.Homeserver.Domain as.HomeserverURL = config.Homeserver.Address as.Host.Hostname = config.AppService.Hostname diff --git a/config/recursivemap.go b/config/recursivemap.go new file mode 100644 index 0000000..49059be --- /dev/null +++ b/config/recursivemap.go @@ -0,0 +1,115 @@ +// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. +// Copyright (C) 2018 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package config + +import ( + "strings" +) + +type RecursiveMap map[interface{}]interface{} + +func (rm RecursiveMap) GetDefault(path string, defVal interface{}) interface{} { + val, ok := rm.Get(path) + if !ok { + return defVal + } + return val +} + +func (rm RecursiveMap) GetMap(path string) RecursiveMap { + val := rm.GetDefault(path, nil) + if val == nil { + return nil + } + + newRM, ok := val.(map[interface{}]interface{}) + if ok { + return RecursiveMap(newRM) + } + return nil +} + +func (rm RecursiveMap) Get(path string) (interface{}, bool) { + if index := strings.IndexRune(path, '.'); index >= 0 { + key := path[:index] + path = path[index+1:] + + submap := rm.GetMap(key) + if submap == nil { + return nil, false + } + return submap.Get(path) + } + val, ok := rm[path] + return val, ok +} + +func (rm RecursiveMap) GetIntDefault(path string, defVal int) int { + val, ok := rm.GetInt(path) + if !ok { + return defVal + } + return val +} + +func (rm RecursiveMap) GetInt(path string) (int, bool) { + val, ok := rm.Get(path) + if !ok { + return 0, ok + } + intVal, ok := val.(int) + return intVal, ok +} + +func (rm RecursiveMap) GetStringDefault(path string, defVal string) string { + val, ok := rm.GetString(path) + if !ok { + return defVal + } + return val +} + +func (rm RecursiveMap) GetString(path string) (string, bool) { + val, ok := rm.Get(path) + if !ok { + return "", ok + } + strVal, ok := val.(string) + return strVal, ok +} + +func (rm RecursiveMap) Set(path string, value interface{}) { + if index := strings.IndexRune(path, '.'); index >= 0 { + key := path[:index] + path = path[index+1:] + nextRM := rm.GetMap(key) + if nextRM == nil { + nextRM = make(RecursiveMap) + rm[key] = nextRM + } + nextRM.Set(path, value) + return + } + rm[path] = value +} + +func (rm RecursiveMap) CopyFrom(otherRM RecursiveMap, path string) { + val, ok := otherRM.Get(path) + if ok { + rm.Set(path, val) + } +} \ No newline at end of file diff --git a/config/registration.go b/config/registration.go index e1e08b5..a764444 100644 --- a/config/registration.go +++ b/config/registration.go @@ -56,7 +56,7 @@ func (config *Config) copyToRegistration(registration *appservice.Registration) registration.SenderLocalpart = config.AppService.Bot.Username userIDRegex, err := regexp.Compile(fmt.Sprintf("^@%s:%s$", - config.Bridge.FormatUsername(".+", "[0-9]+"), + config.Bridge.FormatUsername("[0-9]+"), config.Homeserver.Domain)) if err != nil { return err diff --git a/config/update.go b/config/update.go new file mode 100644 index 0000000..5c390ab --- /dev/null +++ b/config/update.go @@ -0,0 +1,100 @@ +// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. +// Copyright (C) 2018 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package config + +import ( + "io/ioutil" + + "gopkg.in/yaml.v2" +) + +func Update(path, basePath string) error { + oldCfgData, err := ioutil.ReadFile(path) + if err != nil { + return err + } + + oldCfg := make(RecursiveMap) + err = yaml.Unmarshal(oldCfgData, &oldCfg) + if err != nil { + return err + } + + baseCfgData, err := ioutil.ReadFile(basePath) + if err != nil { + return err + } + + baseCfg := make(RecursiveMap) + err = yaml.Unmarshal(baseCfgData, &baseCfg) + if err != nil { + return err + } + + err = runUpdate(oldCfg, baseCfg) + if err != nil { + return err + } + + newCfgData, err := yaml.Marshal(&baseCfg) + if err != nil { + return err + } + + return ioutil.WriteFile(path, newCfgData, 0600) +} + +func runUpdate(oldCfg, newCfg RecursiveMap) error { + cp := func(path string) { + newCfg.CopyFrom(oldCfg, path) + } + + cp("homeserver.address") + cp("homeserver.domain") + + cp("appservice.address") + cp("appservice.hostname") + cp("appservice.port") + + cp("appservice.database.type") + cp("appservice.database.uri") + cp("appservice.state_store_path") + + cp("appservice.id") + cp("appservice.bot.username") + cp("appservice.bot.displayname") + cp("appservice.bot.avatar") + + cp("appservice.bot.as_token") + cp("appservice.bot.hs_token") + + cp("bridge.username_template") + cp("bridge.displayname_template") + + cp("bridge.command_prefix") + + cp("bridge.permissions") + + cp("logging.directory") + cp("logging.file_name_format") + cp("logging.file_date_format") + cp("logging.file_mode") + cp("logging.timestamp_format") + cp("logging.print_level") + + return nil +} diff --git a/database/message.go b/database/message.go index 47d17f9..d4e3a8c 100644 --- a/database/message.go +++ b/database/message.go @@ -30,12 +30,13 @@ type MessageQuery struct { func (mq *MessageQuery) CreateTable() error { _, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message ( - owner VARCHAR(255), - jid VARCHAR(255), - mxid VARCHAR(255) NOT NULL UNIQUE, + chat_jid VARCHAR(25) NOT NULL, + chat_receiver VARCHAR(25) NOT NULL, + jid VARCHAR(255) NOT NULL, + mxid VARCHAR(255) NOT NULL UNIQUE, - PRIMARY KEY (owner, jid), - FOREIGN KEY (owner) REFERENCES user(mxid) + PRIMARY KEY (chat_jid, jid), + FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) )`) return err } @@ -47,8 +48,8 @@ func (mq *MessageQuery) New() *Message { } } -func (mq *MessageQuery) GetAll(owner types.MatrixUserID) (messages []*Message) { - rows, err := mq.db.Query("SELECT * FROM message WHERE owner=?", owner) +func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) { + rows, err := mq.db.Query("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=?", chat.JID, chat.Receiver) if err != nil || rows == nil { return nil } @@ -59,8 +60,8 @@ func (mq *MessageQuery) GetAll(owner types.MatrixUserID) (messages []*Message) { return } -func (mq *MessageQuery) GetByJID(owner types.MatrixUserID, jid types.WhatsAppMessageID) *Message { - return mq.get("SELECT * FROM message WHERE owner=? AND jid=?", owner, jid) +func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.WhatsAppMessageID) *Message { + return mq.get("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=? AND jid=?", chat.JID, chat.Receiver, jid) } func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message { @@ -79,13 +80,13 @@ type Message struct { db *Database log log.Logger - Owner types.MatrixUserID - JID types.WhatsAppMessageID - MXID types.MatrixEventID + Chat PortalKey + JID types.WhatsAppMessageID + MXID types.MatrixEventID } func (msg *Message) Scan(row Scannable) *Message { - err := row.Scan(&msg.Owner, &msg.JID, &msg.MXID) + err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID) if err != nil { if err != sql.ErrNoRows { msg.log.Errorln("Database scan failed:", err) @@ -96,17 +97,17 @@ func (msg *Message) Scan(row Scannable) *Message { } func (msg *Message) Insert() error { - _, err := msg.db.Exec("INSERT INTO message VALUES (?, ?, ?)", msg.Owner, msg.JID, msg.MXID) + _, err := msg.db.Exec("INSERT INTO message VALUES (?, ?, ?)", msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID) if err != nil { - msg.log.Warnfln("Failed to update %s->%s: %v", msg.Owner, msg.JID, err) + msg.log.Warnfln("Failed to update %s: %v", msg.Chat, msg.JID, err) } return err } func (msg *Message) Update() error { - _, err := msg.db.Exec("UPDATE portal SET mxid=? WHERE owner=? AND jid=?", msg.MXID, msg.Owner, msg.JID) + _, err := msg.db.Exec("UPDATE portal SET mxid=? WHERE chat_jid=? AND chat_receiver=? AND jid=?", msg.MXID, msg.Chat.JID, msg.Chat.Receiver, msg.JID) if err != nil { - msg.log.Warnfln("Failed to update %s->%s: %v", msg.Owner, msg.JID, err) + msg.log.Warnfln("Failed to update %s: %v", msg.Chat, msg.JID, err) } return err } diff --git a/database/portal.go b/database/portal.go index 6a2ed54..1a8586d 100644 --- a/database/portal.go +++ b/database/portal.go @@ -18,11 +18,41 @@ package database import ( "database/sql" + "strings" log "maunium.net/go/maulogger" "maunium.net/go/mautrix-whatsapp/types" ) +type PortalKey struct { + JID types.WhatsAppID + Receiver types.WhatsAppID +} + +func GroupPortalKey(jid types.WhatsAppID) PortalKey { + return PortalKey{ + JID: jid, + Receiver: jid, + } +} + +func NewPortalKey(jid, receiver types.WhatsAppID) PortalKey { + if strings.HasSuffix(jid, "@g.us") { + receiver = jid + } + return PortalKey{ + JID: jid, + Receiver: receiver, + } +} + +func (key PortalKey) String() string { + if key.Receiver == key.JID { + return key.JID + } + return key.JID + "-" + key.Receiver +} + type PortalQuery struct { db *Database log log.Logger @@ -30,16 +60,16 @@ type PortalQuery struct { func (pq *PortalQuery) CreateTable() error { _, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal ( - jid VARCHAR(255), - owner VARCHAR(255), - mxid VARCHAR(255) UNIQUE, + jid VARCHAR(25), + receiver VARCHAR(25), + mxid VARCHAR(255) UNIQUE, name VARCHAR(255) NOT NULL, topic VARCHAR(255) NOT NULL, avatar VARCHAR(255) NOT NULL, - PRIMARY KEY (jid, owner), - FOREIGN KEY (owner) REFERENCES user(mxid) + PRIMARY KEY (jid, receiver), + FOREIGN KEY (receiver) REFERENCES user(mxid) )`) return err } @@ -51,8 +81,8 @@ func (pq *PortalQuery) New() *Portal { } } -func (pq *PortalQuery) GetAll(owner types.MatrixUserID) (portals []*Portal) { - rows, err := pq.db.Query("SELECT * FROM portal WHERE owner=?", owner) +func (pq *PortalQuery) GetAll() (portals []*Portal) { + rows, err := pq.db.Query("SELECT * FROM portal") if err != nil || rows == nil { return nil } @@ -63,8 +93,8 @@ func (pq *PortalQuery) GetAll(owner types.MatrixUserID) (portals []*Portal) { return } -func (pq *PortalQuery) GetByJID(owner types.MatrixUserID, jid types.WhatsAppID) *Portal { - return pq.get("SELECT * FROM portal WHERE jid=? AND owner=?", jid, owner) +func (pq *PortalQuery) GetByJID(key PortalKey) *Portal { + return pq.get("SELECT * FROM portal WHERE jid=? AND receiver=?", key.JID, key.Receiver) } func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal { @@ -83,9 +113,8 @@ type Portal struct { db *Database log log.Logger - JID types.WhatsAppID - MXID types.MatrixRoomID - Owner types.MatrixUserID + Key PortalKey + MXID types.MatrixRoomID Name string Topic string @@ -93,7 +122,7 @@ type Portal struct { } func (portal *Portal) Scan(row Scannable) *Portal { - err := row.Scan(&portal.JID, &portal.Owner, &portal.MXID, &portal.Name, &portal.Topic, &portal.Avatar) + err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &portal.MXID, &portal.Name, &portal.Topic, &portal.Avatar) if err != nil { if err != sql.ErrNoRows { portal.log.Errorln("Database scan failed:", err) @@ -103,15 +132,18 @@ func (portal *Portal) Scan(row Scannable) *Portal { return portal } -func (portal *Portal) Insert() error { - var mxid *string +func (portal *Portal) mxidPtr() *string { if len(portal.MXID) > 0 { - mxid = &portal.MXID + return &portal.MXID } + return nil +} + +func (portal *Portal) Insert() error { _, err := portal.db.Exec("INSERT INTO portal VALUES (?, ?, ?, ?, ?, ?)", - portal.JID, portal.Owner, mxid, portal.Name, portal.Topic, portal.Avatar) + portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar) if err != nil { - portal.log.Warnfln("Failed to insert %s->%s: %v", portal.JID, portal.Owner, err) + portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err) } return err } @@ -121,10 +153,10 @@ func (portal *Portal) Update() error { if len(portal.MXID) > 0 { mxid = &portal.MXID } - _, err := portal.db.Exec("UPDATE portal SET mxid=?, name=?, topic=?, avatar=? WHERE jid=? AND owner=?", - mxid, portal.Name, portal.Topic, portal.Avatar, portal.JID, portal.Owner) + _, err := portal.db.Exec("UPDATE portal SET mxid=?, name=?, topic=?, avatar=? WHERE jid=? AND receiver=?", + mxid, portal.Name, portal.Topic, portal.Avatar, portal.Key.JID, portal.Key.Receiver) if err != nil { - portal.log.Warnfln("Failed to update %s->%s: %v", portal.JID, portal.Owner, err) + portal.log.Warnfln("Failed to update %s: %v", portal.Key, err) } return err } diff --git a/database/puppet.go b/database/puppet.go index 434fbe0..401a5c8 100644 --- a/database/puppet.go +++ b/database/puppet.go @@ -30,13 +30,9 @@ type PuppetQuery struct { func (pq *PuppetQuery) CreateTable() error { _, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet ( - jid VARCHAR(255), - receiver VARCHAR(255), - + jid VARCHAR(25) PRIMARY KEY, displayname VARCHAR(255), - avatar VARCHAR(255), - - PRIMARY KEY(jid, receiver) + avatar VARCHAR(255) )`) return err } @@ -48,8 +44,8 @@ func (pq *PuppetQuery) New() *Puppet { } } -func (pq *PuppetQuery) GetAll(receiver types.MatrixUserID) (puppets []*Puppet) { - rows, err := pq.db.Query("SELECT * FROM puppet WHERE receiver=%s") +func (pq *PuppetQuery) GetAll() (puppets []*Puppet) { + rows, err := pq.db.Query("SELECT * FROM puppet") if err != nil || rows == nil { return nil } @@ -60,8 +56,8 @@ func (pq *PuppetQuery) GetAll(receiver types.MatrixUserID) (puppets []*Puppet) { return } -func (pq *PuppetQuery) Get(jid types.WhatsAppID, receiver types.MatrixUserID) *Puppet { - row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=? AND receiver=?", jid, receiver) +func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet { + row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=?", jid) if row == nil { return nil } @@ -72,15 +68,13 @@ type Puppet struct { db *Database log log.Logger - JID types.WhatsAppID - Receiver types.MatrixUserID - + JID types.WhatsAppID Displayname string Avatar string } func (puppet *Puppet) Scan(row Scannable) *Puppet { - err := row.Scan(&puppet.JID, &puppet.Receiver, &puppet.Displayname, &puppet.Avatar) + err := row.Scan(&puppet.JID, &puppet.Displayname, &puppet.Avatar) if err != nil { if err != sql.ErrNoRows { puppet.log.Errorln("Database scan failed:", err) @@ -91,20 +85,19 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet { } func (puppet *Puppet) Insert() error { - _, err := puppet.db.Exec("INSERT INTO puppet VALUES (?, ?, ?, ?)", - puppet.JID, puppet.Receiver, puppet.Displayname, puppet.Avatar) + _, err := puppet.db.Exec("INSERT INTO puppet VALUES (?, ?, ?)", + puppet.JID, puppet.Displayname, puppet.Avatar) if err != nil { - puppet.log.Errorfln("Failed to insert %s->%s: %v", puppet.JID, puppet.Receiver, err) + puppet.log.Errorfln("Failed to insert %s: %v", puppet.JID, err) } return err } func (puppet *Puppet) Update() error { - _, err := puppet.db.Exec("UPDATE puppet SET displayname=?, avatar=? WHERE jid=? AND receiver=?", - puppet.Displayname, puppet.Avatar, - puppet.JID, puppet.Receiver) + _, err := puppet.db.Exec("UPDATE puppet SET displayname=?, avatar=? WHERE jid=?", + puppet.Displayname, puppet.Avatar, puppet.JID) if err != nil { - puppet.log.Errorfln("Failed to update %s->%s: %v", puppet.JID, puppet.Receiver, err) + puppet.log.Errorfln("Failed to update %s->%s: %v", puppet.JID, err) } return err } diff --git a/database/user.go b/database/user.go index 15dc793..b650cb3 100644 --- a/database/user.go +++ b/database/user.go @@ -32,6 +32,7 @@ type UserQuery struct { func (uq *UserQuery) CreateTable() error { _, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS user ( mxid VARCHAR(255) PRIMARY KEY, + jid VARCHAR(25) UNIQUE, management_room VARCHAR(255), @@ -64,7 +65,7 @@ func (uq *UserQuery) GetAll() (users []*User) { return } -func (uq *UserQuery) Get(userID types.MatrixUserID) *User { +func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User { row := uq.db.QueryRow("SELECT * FROM user WHERE mxid=?", userID) if row == nil { return nil @@ -72,18 +73,27 @@ func (uq *UserQuery) Get(userID types.MatrixUserID) *User { return uq.New().Scan(row) } +func (uq *UserQuery) GetByJID(userID types.WhatsAppID) *User { + row := uq.db.QueryRow("SELECT * FROM user WHERE jid=?", userID) + if row == nil { + return nil + } + return uq.New().Scan(row) +} + type User struct { db *Database log log.Logger - ID types.MatrixUserID + MXID types.MatrixUserID + JID types.WhatsAppID ManagementRoom types.MatrixRoomID Session *whatsapp.Session } func (user *User) Scan(row Scannable) *User { sess := whatsapp.Session{} - err := row.Scan(&user.ID, &user.ManagementRoom, &sess.ClientId, &sess.ClientToken, &sess.ServerToken, + err := row.Scan(&user.MXID, &user.JID, &user.ManagementRoom, &sess.ClientId, &sess.ClientToken, &sess.ServerToken, &sess.EncKey, &sess.MacKey, &sess.Wid) if err != nil { if err != sql.ErrNoRows { @@ -99,23 +109,32 @@ func (user *User) Scan(row Scannable) *User { return user } -func (user *User) Insert() error { - var sess whatsapp.Session +func (user *User) jidPtr() *string { + if len(user.JID) > 0 { + return &user.JID + } + return nil +} + +func (user *User) sessionUnptr() (sess whatsapp.Session) { if user.Session != nil { sess = *user.Session } - _, err := user.db.Exec("INSERT INTO user VALUES (?, ?, ?, ?, ?, ?, ?, ?)", user.ID, user.ManagementRoom, + return +} + +func (user *User) Insert() error { + sess := user.sessionUnptr() + _, err := user.db.Exec("INSERT INTO user VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", user.MXID, user.jidPtr(), user.ManagementRoom, sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid) return err } func (user *User) Update() error { - var sess whatsapp.Session - if user.Session != nil { - sess = *user.Session - } - _, err := user.db.Exec("UPDATE user SET management_room=?, client_id=?, client_token=?, server_token=?, enc_key=?, mac_key=?, wid=? WHERE mxid=?", - user.ManagementRoom, - sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid, user.ID) + sess := user.sessionUnptr() + _, err := user.db.Exec("UPDATE user SET jid=?, management_room=?, client_id=?, client_token=?, server_token=?, enc_key=?, mac_key=?, wid=? WHERE mxid=?", + user.jidPtr(), user.ManagementRoom, + sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid, + user.MXID) return err } diff --git a/example-config.yaml b/example-config.yaml index d3a1f60..bbc609c 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -21,7 +21,6 @@ appservice: type: sqlite3 # The database URI. Usually file name. https://github.com/mattn/go-sqlite3#connection-string uri: mautrix-whatsapp.db - # Path to the Matrix room state store. state_store_path: ./mx-state.json @@ -43,15 +42,15 @@ appservice: # Bridge config. Currently unused. bridge: # Localpart template of MXIDs for WhatsApp users. - # {{.Receiver}} is replaced with the WhatsApp user ID of the Matrix user receiving messages. - # {{.UserID}} is replaced with the user ID of the WhatsApp user. - username_template: "whatsapp_{{.Receiver}}_{{.UserID}}" + # {{.}} is replaced with the phone number of the WhatsApp user. + username_template: whatsapp_{{.}} # Displayname template for WhatsApp users. - # {{.Name}} - display name - # {{.Short}} - short display name (usually first name) - # {{.Notify}} - nickname (maybe set by the target WhatsApp user) + # {{.Notify}} - nickname set by the WhatsApp user # {{.Jid}} - phone number (international format) - displayname_template: "{{if .Name}}{{.Name}}{{else if .Notify}}{{.Notify}}{{else if .Short}}{{.Short}}{{else}}{{.Jid}}{{end}}" + # The following variables are also available, but will cause problems on multi-user instances: + # {{.Name}} - display name from contact list + # {{.Short}} - short display name from contact list + displayname_template: "{{if .Notify}}{{.Notify}}{{else}}{{.Jid}}{{end}} (WA)" # The prefix for commands. Only required in non-management rooms. command_prefix: "!wa" @@ -72,8 +71,8 @@ bridge: logging: # The directory for log files. Will be created if not found. directory: ./logs - # Available variables: .date for the file date and .index for different log files on the same day. - file_name_format: "{{.date}}-{{.index}.log" + # Available variables: .Date for the file date and .Index for different log files on the same day. + file_name_format: "{{.Date}}-{{.Index}}.log" # Date format for file names in the Go time format: https://golang.org/pkg/time/#pkg-constants file_date_format: 2006-01-02 # Log file permissions. diff --git a/formatting.go b/formatting.go index 7d04cc0..bd3eb3c 100644 --- a/formatting.go +++ b/formatting.go @@ -18,58 +18,71 @@ package main import ( "fmt" + "html" "regexp" "strings" + "maunium.net/go/gomatrix" "maunium.net/go/gomatrix/format" "maunium.net/go/mautrix-whatsapp/whatsapp-ext" ) -func (user *User) newHTMLParser() *format.HTMLParser { - return &format.HTMLParser{ - TabsToSpaces: 4, - Newline: "\n", - - PillConverter: func(mxid, eventID string) string { - if mxid[0] == '@' { - puppet := user.GetPuppetByMXID(mxid) - fmt.Println(mxid, puppet) - if puppet != nil { - return "@" + puppet.PhoneNumber() - } - } - return mxid - }, - BoldConverter: func(text string) string { - return fmt.Sprintf("*%s*", text) - }, - ItalicConverter: func(text string) string { - return fmt.Sprintf("_%s_", text) - }, - StrikethroughConverter: func(text string) string { - return fmt.Sprintf("~%s~", text) - }, - MonospaceConverter: func(text string) string { - return fmt.Sprintf("```%s```", text) - }, - MonospaceBlockConverter: func(text string) string { - return fmt.Sprintf("```%s```", text) - }, - } -} - var italicRegex = regexp.MustCompile("([\\s>~*]|^)_(.+?)_([^a-zA-Z\\d]|$)") var boldRegex = regexp.MustCompile("([\\s>_~]|^)\\*(.+?)\\*([^a-zA-Z\\d]|$)") var strikethroughRegex = regexp.MustCompile("([\\s>_*]|^)~(.+?)~([^a-zA-Z\\d]|$)") var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```") var mentionRegex = regexp.MustCompile("@[0-9]+") -func (user *User) newWhatsAppFormatMaps() (map[*regexp.Regexp]string, map[*regexp.Regexp]func(string) string, map[*regexp.Regexp]func(string) string) { - return map[*regexp.Regexp]string{ - italicRegex: "$1$2$3", - boldRegex: "$1$2$3", - strikethroughRegex: "$1$2$3", - }, map[*regexp.Regexp]func(string) string{ +type Formatter struct { + bridge *Bridge + + matrixHTMLParser *format.HTMLParser + + waReplString map[*regexp.Regexp]string + waReplFunc map[*regexp.Regexp]func(string) string + waReplFuncText map[*regexp.Regexp]func(string) string +} + +func NewFormatter(bridge *Bridge) *Formatter { + formatter := &Formatter{ + bridge: bridge, + matrixHTMLParser: &format.HTMLParser{ + TabsToSpaces: 4, + Newline: "\n", + + PillConverter: func(mxid, eventID string) string { + if mxid[0] == '@' { + puppet := bridge.GetPuppetByMXID(mxid) + fmt.Println(mxid, puppet) + if puppet != nil { + return "@" + puppet.PhoneNumber() + } + } + return mxid + }, + BoldConverter: func(text string) string { + return fmt.Sprintf("*%s*", text) + }, + ItalicConverter: func(text string) string { + return fmt.Sprintf("_%s_", text) + }, + StrikethroughConverter: func(text string) string { + return fmt.Sprintf("~%s~", text) + }, + MonospaceConverter: func(text string) string { + return fmt.Sprintf("```%s```", text) + }, + MonospaceBlockConverter: func(text string) string { + return fmt.Sprintf("```%s```", text) + }, + }, + waReplString: map[*regexp.Regexp]string{ + italicRegex: "$1$2$3", + boldRegex: "$1$2$3", + strikethroughRegex: "$1$2$3", + }, + } + formatter.waReplFunc = map[*regexp.Regexp]func(string) string{ codeBlockRegex: func(str string) string { str = str[3 : len(str)-3] if strings.ContainsRune(str, '\n') { @@ -78,18 +91,47 @@ func (user *User) newWhatsAppFormatMaps() (map[*regexp.Regexp]string, map[*regex return fmt.Sprintf("%s", str) }, mentionRegex: func(str string) string { - jid := str[1:] + whatsappExt.NewUserSuffix - puppet := user.GetPuppetByJID(jid) - mxid := puppet.MXID - if jid == user.JID() { - mxid = user.ID - } - return fmt.Sprintf(`%s`, mxid, puppet.Displayname) - }, - }, map[*regexp.Regexp]func(string)string { - mentionRegex: func(str string) string { - puppet := user.GetPuppetByJID(str[1:] + whatsappExt.NewUserSuffix) - return puppet.Displayname + mxid, displayname := formatter.getMatrixInfoByJID(str[1:] + whatsappExt.NewUserSuffix) + return fmt.Sprintf(`%s`, mxid, displayname) }, } + formatter.waReplFuncText = map[*regexp.Regexp]func(string) string{ + mentionRegex: func(str string) string { + _, displayname := formatter.getMatrixInfoByJID(str[1:] + whatsappExt.NewUserSuffix) + return displayname + }, + } + return formatter +} + +func (formatter *Formatter) getMatrixInfoByJID(jid string) (mxid, displayname string) { + if user := formatter.bridge.GetUserByJID(jid); user != nil { + mxid = user.MXID + displayname = user.MXID + } else if puppet := formatter.bridge.GetPuppetByJID(jid); puppet != nil { + mxid = puppet.MXID + displayname = puppet.Displayname + } + return +} + +func (formatter *Formatter) ParseWhatsApp(content *gomatrix.Content) { + output := html.EscapeString(content.Body) + for regex, replacement := range formatter.waReplString { + output = regex.ReplaceAllString(output, replacement) + } + for regex, replacer := range formatter.waReplFunc { + output = regex.ReplaceAllStringFunc(output, replacer) + } + if output != content.Body { + content.FormattedBody = output + content.Format = gomatrix.FormatHTML + for regex, replacer := range formatter.waReplFuncText { + content.Body = regex.ReplaceAllStringFunc(content.Body, replacer) + } + } +} + +func (formatter *Formatter) ParseMatrix(html string) string { + return formatter.matrixHTMLParser.Parse(html) } diff --git a/main.go b/main.go index c5a05ac..50bac7e 100644 --- a/main.go +++ b/main.go @@ -20,6 +20,7 @@ import ( "fmt" "os" "os/signal" + "sync" "syscall" flag "maunium.net/go/mauflag" @@ -31,6 +32,7 @@ import ( ) var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String() +var baseConfigPath = flag.MakeFull("b", "base-config", "The path to the example config file.", "example-config.yaml").String() var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String() var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool() var wantHelp, _ = flag.MakeHelpFlag() @@ -58,29 +60,47 @@ func (bridge *Bridge) GenerateRegistration() { } type Bridge struct { - AppService *appservice.AppService + AS *appservice.AppService EventProcessor *appservice.EventProcessor MatrixHandler *MatrixHandler Config *config.Config DB *database.Database Log log.Logger + StateStore *AutosavingStateStore + Bot *appservice.IntentAPI + Formatter *Formatter - StateStore *AutosavingStateStore - - users map[types.MatrixUserID]*User - managementRooms map[types.MatrixRoomID]*User + usersByMXID map[types.MatrixUserID]*User + usersByJID map[types.WhatsAppID]*User + usersLock sync.Mutex + managementRooms map[types.MatrixRoomID]*User + managementRoomsLock sync.Mutex + portalsByMXID map[types.MatrixRoomID]*Portal + portalsByJID map[database.PortalKey]*Portal + portalsLock sync.Mutex + puppets map[types.WhatsAppID]*Puppet + puppetsLock sync.Mutex } func NewBridge() *Bridge { bridge := &Bridge{ - users: make(map[types.MatrixUserID]*User), + usersByMXID: make(map[types.MatrixUserID]*User), + usersByJID: make(map[types.WhatsAppID]*User), managementRooms: make(map[types.MatrixRoomID]*User), + portalsByMXID: make(map[types.MatrixRoomID]*Portal), + portalsByJID: make(map[database.PortalKey]*Portal), + puppets: make(map[types.WhatsAppID]*Puppet), } - var err error + err := config.Update(*configPath, *baseConfigPath) + if err != nil { + fmt.Fprintln(os.Stderr, "Failed to update config:", err) + os.Exit(10) + } + bridge.Config, err = config.Load(*configPath) if err != nil { fmt.Fprintln(os.Stderr, "Failed to load config:", err) - os.Exit(10) + os.Exit(11) } return bridge } @@ -88,46 +108,55 @@ func NewBridge() *Bridge { func (bridge *Bridge) Init() { var err error - bridge.AppService, err = bridge.Config.MakeAppService() + bridge.AS, err = bridge.Config.MakeAppService() if err != nil { fmt.Fprintln(os.Stderr, "Failed to initialize AppService:", err) - os.Exit(11) + os.Exit(12) } - bridge.AppService.Init() - bridge.Log = bridge.AppService.Log + bridge.AS.Init() + bridge.Bot = bridge.AS.BotIntent() + + bridge.Log = log.Create() + bridge.Config.Logging.Configure(bridge.Log) log.DefaultLogger = bridge.Log.(*log.BasicLogger) - bridge.AppService.Log = log.Sub("Matrix") + err = log.OpenFile() + if err != nil { + fmt.Fprintln(os.Stderr, "Failed to open log file:", err) + os.Exit(13) + } + bridge.AS.Log = log.Sub("Matrix") bridge.Log.Debugln("Initializing state store") bridge.StateStore = NewAutosavingStateStore(bridge.Config.AppService.StateStore) err = bridge.StateStore.Load() if err != nil { bridge.Log.Fatalln("Failed to load state store:", err) - os.Exit(12) + os.Exit(14) } - bridge.AppService.StateStore = bridge.StateStore + bridge.AS.StateStore = bridge.StateStore bridge.Log.Debugln("Initializing database") bridge.DB, err = database.New(bridge.Config.AppService.Database.URI) if err != nil { bridge.Log.Fatalln("Failed to initialize database:", err) - os.Exit(13) + os.Exit(15) } bridge.Log.Debugln("Initializing Matrix event processor") - bridge.EventProcessor = appservice.NewEventProcessor(bridge.AppService) + bridge.EventProcessor = appservice.NewEventProcessor(bridge.AS) bridge.Log.Debugln("Initializing Matrix event handler") bridge.MatrixHandler = NewMatrixHandler(bridge) + bridge.Formatter = NewFormatter(bridge) } func (bridge *Bridge) Start() { err := bridge.DB.CreateTables() if err != nil { bridge.Log.Fatalln("Failed to create database tables:", err) - os.Exit(14) + os.Exit(16) } bridge.Log.Debugln("Starting application service HTTP server") - go bridge.AppService.Start() + go bridge.AS.Start() bridge.Log.Debugln("Starting event processor") go bridge.EventProcessor.Start() go bridge.UpdateBotProfile() @@ -140,18 +169,18 @@ func (bridge *Bridge) UpdateBotProfile() { var err error if botConfig.Avatar == "remove" { - err = bridge.AppService.BotIntent().SetAvatarURL("") + err = bridge.AS.BotIntent().SetAvatarURL("") } else if len(botConfig.Avatar) > 0 { - err = bridge.AppService.BotIntent().SetAvatarURL(botConfig.Avatar) + err = bridge.AS.BotIntent().SetAvatarURL(botConfig.Avatar) } if err != nil { bridge.Log.Warnln("Failed to update bot avatar:", err) } if botConfig.Displayname == "remove" { - err = bridge.AppService.BotIntent().SetDisplayName("") + err = bridge.AS.BotIntent().SetDisplayName("") } else if len(botConfig.Avatar) > 0 { - err = bridge.AppService.BotIntent().SetDisplayName(botConfig.Displayname) + err = bridge.AS.BotIntent().SetDisplayName(botConfig.Displayname) } if err != nil { bridge.Log.Warnln("Failed to update bot displayname:", err) @@ -165,7 +194,7 @@ func (bridge *Bridge) StartUsers() { } func (bridge *Bridge) Stop() { - bridge.AppService.Stop() + bridge.AS.Stop() bridge.EventProcessor.Stop() err := bridge.StateStore.Save() if err != nil { diff --git a/matrix.go b/matrix.go index adb76d4..de63727 100644 --- a/matrix.go +++ b/matrix.go @@ -35,7 +35,7 @@ type MatrixHandler struct { func NewMatrixHandler(bridge *Bridge) *MatrixHandler { handler := &MatrixHandler{ bridge: bridge, - as: bridge.AppService, + as: bridge.AS, log: bridge.Log.Sub("Matrix"), cmd: NewCommandHandler(bridge), } @@ -50,7 +50,7 @@ func NewMatrixHandler(bridge *Bridge) *MatrixHandler { func (mx *MatrixHandler) HandleBotInvite(evt *gomatrix.Event) { intent := mx.as.BotIntent() - user := mx.bridge.GetUser(evt.Sender) + user := mx.bridge.GetUserByMXID(evt.Sender) if user == nil { return } @@ -85,7 +85,7 @@ func (mx *MatrixHandler) HandleBotInvite(evt *gomatrix.Event) { for mxid, _ := range members.Joined { if mxid == intent.UserID || mxid == evt.Sender { continue - } else if _, _, ok := mx.bridge.ParsePuppetMXID(types.MatrixUserID(mxid)); ok { + } else if _, ok := mx.bridge.ParsePuppetMXID(types.MatrixUserID(mxid)); ok { hasPuppets = true continue } @@ -96,7 +96,7 @@ func (mx *MatrixHandler) HandleBotInvite(evt *gomatrix.Event) { } if !hasPuppets { - user := mx.bridge.GetUser(types.MatrixUserID(evt.Sender)) + user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender)) user.SetManagementRoom(types.MatrixRoomID(resp.RoomID)) intent.SendNotice(string(user.ManagementRoom), "This room has been registered as your bridge management/status room.") mx.log.Debugln(resp.RoomID, "registered as a management room with", evt.Sender) @@ -110,12 +110,12 @@ func (mx *MatrixHandler) HandleMembership(evt *gomatrix.Event) { } func (mx *MatrixHandler) HandleRoomMetadata(evt *gomatrix.Event) { - user := mx.bridge.GetUser(types.MatrixUserID(evt.Sender)) - if user == nil || !user.Whitelisted { + user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender)) + if user == nil || !user.Whitelisted || !user.IsLoggedIn() { return } - portal := user.GetPortalByMXID(evt.RoomID) + portal := mx.bridge.GetPortalByMXID(evt.RoomID) if portal == nil || portal.IsPrivateChat() { return } @@ -124,7 +124,7 @@ func (mx *MatrixHandler) HandleRoomMetadata(evt *gomatrix.Event) { var err error switch evt.Type { case gomatrix.StateRoomName: - resp, err = user.Conn.UpdateGroupSubject(evt.Content.Name, portal.JID) + resp, err = user.Conn.UpdateGroupSubject(evt.Content.Name, portal.Key.JID) case gomatrix.StateRoomAvatar: return case gomatrix.StateTopic: @@ -140,7 +140,7 @@ func (mx *MatrixHandler) HandleRoomMetadata(evt *gomatrix.Event) { func (mx *MatrixHandler) HandleMessage(evt *gomatrix.Event) { roomID := types.MatrixRoomID(evt.RoomID) - user := mx.bridge.GetUser(types.MatrixUserID(evt.Sender)) + user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender)) if !user.Whitelisted { return @@ -158,8 +158,12 @@ func (mx *MatrixHandler) HandleMessage(evt *gomatrix.Event) { } } - portal := user.GetPortalByMXID(roomID) + if !user.IsLoggedIn() { + return + } + + portal := mx.bridge.GetPortalByMXID(roomID) if portal != nil { - portal.HandleMatrixMessage(evt) + portal.HandleMatrixMessage(user, evt) } } diff --git a/portal.go b/portal.go index 1538431..e597d1c 100644 --- a/portal.go +++ b/portal.go @@ -20,7 +20,6 @@ import ( "bytes" "encoding/hex" "fmt" - "html" "image" "image/gif" "image/jpeg" @@ -41,57 +40,56 @@ import ( "maunium.net/go/mautrix-whatsapp/whatsapp-ext" ) -func (user *User) GetPortalByMXID(mxid types.MatrixRoomID) *Portal { - user.portalsLock.Lock() - defer user.portalsLock.Unlock() - portal, ok := user.portalsByMXID[mxid] +func (bridge *Bridge) GetPortalByMXID(mxid types.MatrixRoomID) *Portal { + bridge.portalsLock.Lock() + defer bridge.portalsLock.Unlock() + portal, ok := bridge.portalsByMXID[mxid] if !ok { - dbPortal := user.bridge.DB.Portal.GetByMXID(mxid) - if dbPortal == nil || dbPortal.Owner != user.ID { + dbPortal := bridge.DB.Portal.GetByMXID(mxid) + if dbPortal == nil { return nil } - portal = user.NewPortal(dbPortal) - user.portalsByJID[portal.JID] = portal + portal = bridge.NewPortal(dbPortal) + bridge.portalsByJID[portal.Key] = portal if len(portal.MXID) > 0 { - user.portalsByMXID[portal.MXID] = portal + bridge.portalsByMXID[portal.MXID] = portal } } return portal } -func (user *User) GetPortalByJID(jid types.WhatsAppID) *Portal { - user.portalsLock.Lock() - defer user.portalsLock.Unlock() - portal, ok := user.portalsByJID[jid] +func (bridge *Bridge) GetPortalByJID(key database.PortalKey) *Portal { + bridge.portalsLock.Lock() + defer bridge.portalsLock.Unlock() + portal, ok := bridge.portalsByJID[key] if !ok { - dbPortal := user.bridge.DB.Portal.GetByJID(user.ID, jid) + dbPortal := bridge.DB.Portal.GetByJID(key) if dbPortal == nil { - dbPortal = user.bridge.DB.Portal.New() - dbPortal.JID = jid - dbPortal.Owner = user.ID + dbPortal = bridge.DB.Portal.New() + dbPortal.Key = key dbPortal.Insert() } - portal = user.NewPortal(dbPortal) - user.portalsByJID[portal.JID] = portal + portal = bridge.NewPortal(dbPortal) + bridge.portalsByJID[portal.Key] = portal if len(portal.MXID) > 0 { - user.portalsByMXID[portal.MXID] = portal + bridge.portalsByMXID[portal.MXID] = portal } } return portal } -func (user *User) GetAllPortals() []*Portal { - user.portalsLock.Lock() - defer user.portalsLock.Unlock() - dbPortals := user.bridge.DB.Portal.GetAll(user.ID) +func (bridge *Bridge) GetAllPortals() []*Portal { + bridge.portalsLock.Lock() + defer bridge.portalsLock.Unlock() + dbPortals := bridge.DB.Portal.GetAll() output := make([]*Portal, len(dbPortals)) for index, dbPortal := range dbPortals { - portal, ok := user.portalsByJID[dbPortal.JID] + portal, ok := bridge.portalsByJID[dbPortal.Key] if !ok { - portal = user.NewPortal(dbPortal) - user.portalsByJID[dbPortal.JID] = portal + portal = bridge.NewPortal(dbPortal) + bridge.portalsByJID[portal.Key] = portal if len(dbPortal.MXID) > 0 { - user.portalsByMXID[dbPortal.MXID] = portal + bridge.portalsByMXID[dbPortal.MXID] = portal } } output[index] = portal @@ -99,19 +97,17 @@ func (user *User) GetAllPortals() []*Portal { return output } -func (user *User) NewPortal(dbPortal *database.Portal) *Portal { +func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal { return &Portal{ Portal: dbPortal, - user: user, - bridge: user.bridge, - log: user.log.Sub(fmt.Sprintf("Portal/%s", dbPortal.JID)), + bridge: bridge, + log: bridge.Log.Sub(fmt.Sprintf("Portal/%s", dbPortal.Key)), } } type Portal struct { *database.Portal - user *User bridge *Bridge log log.Logger @@ -126,9 +122,16 @@ func (portal *Portal) SyncParticipants(metadata *whatsappExt.GroupInfo) { changed = true } for _, participant := range metadata.Participants { - puppet := portal.user.GetPuppetByJID(participant.JID) + puppet := portal.bridge.GetPuppetByJID(participant.JID) puppet.Intent().EnsureJoined(portal.MXID) + user := portal.bridge.GetUserByJID(participant.JID) + if !portal.bridge.AS.StateStore.IsInvited(portal.MXID, user.MXID) { + portal.MainIntent().InviteUser(portal.MXID, &gomatrix.ReqInviteUser{ + UserID: user.MXID, + }) + } + expectedLevel := 0 if participant.IsSuperAdmin { expectedLevel = 95 @@ -136,9 +139,8 @@ func (portal *Portal) SyncParticipants(metadata *whatsappExt.GroupInfo) { expectedLevel = 50 } changed = levels.EnsureUserLevel(puppet.MXID, expectedLevel) || changed - - if participant.JID == portal.user.JID() { - changed = levels.EnsureUserLevel(portal.user.ID, expectedLevel) || changed + if user != nil { + changed = levels.EnsureUserLevel(user.MXID, expectedLevel) || changed } } if changed { @@ -146,10 +148,10 @@ func (portal *Portal) SyncParticipants(metadata *whatsappExt.GroupInfo) { } } -func (portal *Portal) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool { +func (portal *Portal) UpdateAvatar(user *User, avatar *whatsappExt.ProfilePicInfo) bool { if avatar == nil { var err error - avatar, err = portal.user.Conn.GetProfilePicThumb(portal.JID) + avatar, err = user.Conn.GetProfilePicThumb(portal.Key.JID) if err != nil { portal.log.Errorln(err) return false @@ -184,7 +186,7 @@ func (portal *Portal) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool { func (portal *Portal) UpdateName(name string, setBy types.WhatsAppID) bool { if portal.Name != name { - intent := portal.user.GetPuppetByJID(setBy).Intent() + intent := portal.bridge.GetPuppetByJID(setBy).Intent() _, err := intent.SetRoomName(portal.MXID, name) if err == nil { portal.Name = name @@ -197,7 +199,7 @@ func (portal *Portal) UpdateName(name string, setBy types.WhatsAppID) bool { func (portal *Portal) UpdateTopic(topic string, setBy types.WhatsAppID) bool { if portal.Topic != topic { - intent := portal.user.GetPuppetByJID(setBy).Intent() + intent := portal.bridge.GetPuppetByJID(setBy).Intent() _, err := intent.SetRoomTopic(portal.MXID, topic) if err == nil { portal.Topic = topic @@ -208,8 +210,8 @@ func (portal *Portal) UpdateTopic(topic string, setBy types.WhatsAppID) bool { return false } -func (portal *Portal) UpdateMetadata() bool { - metadata, err := portal.user.Conn.GetGroupMetaData(portal.JID) +func (portal *Portal) UpdateMetadata(user *User) bool { + metadata, err := user.Conn.GetGroupMetaData(portal.Key.JID) if err != nil { portal.log.Errorln(err) return false @@ -221,25 +223,23 @@ func (portal *Portal) UpdateMetadata() bool { return update } -func (portal *Portal) Sync(contact whatsapp.Contact) { +func (portal *Portal) Sync(user *User, contact whatsapp.Contact) { + if portal.IsPrivateChat() { + return + } + if len(portal.MXID) == 0 { - if !portal.IsPrivateChat() { - portal.Name = contact.Name - } - err := portal.CreateMatrixRoom() + portal.Name = contact.Name + err := portal.CreateMatrixRoom([]string{user.MXID}) if err != nil { portal.log.Errorln("Failed to create portal room:", err) return } } - if portal.IsPrivateChat() { - return - } - update := false - update = portal.UpdateMetadata() || update - update = portal.UpdateAvatar(nil) || update + update = portal.UpdateMetadata(user) || update + update = portal.UpdateAvatar(user, nil) || update if update { portal.Update() } @@ -277,11 +277,12 @@ func (portal *Portal) ChangeAdminStatus(jids []string, setAdmin bool) { } changed := false for _, jid := range jids { - puppet := portal.user.GetPuppetByJID(jid) + puppet := portal.bridge.GetPuppetByJID(jid) changed = levels.EnsureUserLevel(puppet.MXID, newLevel) || changed - if jid == portal.user.JID() { - changed = levels.EnsureUserLevel(portal.user.ID, newLevel) || changed + user := portal.bridge.GetUserByJID(jid) + if user != nil { + changed = levels.EnsureUserLevel(user.MXID, newLevel) || changed } } if changed { @@ -312,15 +313,15 @@ func (portal *Portal) RestrictMetadataChanges(restrict bool) { newLevel = 50 } changed := false - changed = levels.EnsureEventLevel(gomatrix.StateRoomName, true, newLevel) || changed - changed = levels.EnsureEventLevel(gomatrix.StateRoomAvatar, true, newLevel) || changed - changed = levels.EnsureEventLevel(gomatrix.StateTopic, true, newLevel) || changed + changed = levels.EnsureEventLevel(gomatrix.StateRoomName, newLevel) || changed + changed = levels.EnsureEventLevel(gomatrix.StateRoomAvatar, newLevel) || changed + changed = levels.EnsureEventLevel(gomatrix.StateTopic, newLevel) || changed if changed { portal.MainIntent().SetPowerLevels(portal.MXID, levels) } } -func (portal *Portal) CreateMatrixRoom() error { +func (portal *Portal) CreateMatrixRoom(invite []string) error { portal.roomCreateLock.Lock() defer portal.roomCreateLock.Unlock() if len(portal.MXID) > 0 { @@ -330,7 +331,6 @@ func (portal *Portal) CreateMatrixRoom() error { name := portal.Name topic := portal.Topic isPrivateChat := false - invite := []string{portal.user.ID} if portal.IsPrivateChat() { name = "" topic = "WhatsApp private chat" @@ -360,18 +360,18 @@ func (portal *Portal) CreateMatrixRoom() error { } func (portal *Portal) IsPrivateChat() bool { - return strings.HasSuffix(portal.JID, whatsappExt.NewUserSuffix) + return strings.HasSuffix(portal.Key.JID, whatsappExt.NewUserSuffix) } func (portal *Portal) MainIntent() *appservice.IntentAPI { if portal.IsPrivateChat() { - return portal.user.GetPuppetByJID(portal.JID).Intent() + return portal.bridge.GetPuppetByJID(portal.Key.JID).Intent() } - return portal.bridge.AppService.BotIntent() + return portal.bridge.AS.BotIntent() } func (portal *Portal) IsDuplicate(id types.WhatsAppMessageID) bool { - msg := portal.bridge.DB.Message.GetByJID(portal.Owner, id) + msg := portal.bridge.DB.Message.GetByJID(portal.Key, id) if msg != nil { return true } @@ -380,7 +380,7 @@ func (portal *Portal) IsDuplicate(id types.WhatsAppMessageID) bool { func (portal *Portal) MarkHandled(jid types.WhatsAppMessageID, mxid types.MatrixEventID) { msg := portal.bridge.DB.Message.New() - msg.Owner = portal.Owner + msg.Chat = portal.Key msg.JID = jid msg.MXID = mxid msg.Insert() @@ -392,7 +392,7 @@ func (portal *Portal) GetMessageIntent(info whatsapp.MessageInfo) *appservice.In // TODO handle own messages in private chats properly return nil } - return portal.user.GetPuppetByJID(portal.user.JID()).Intent() + return portal.bridge.GetPuppetByJID(portal.Key.Receiver).Intent() } else if portal.IsPrivateChat() { return portal.MainIntent() } else if len(info.SenderJid) == 0 { @@ -402,14 +402,14 @@ func (portal *Portal) GetMessageIntent(info whatsapp.MessageInfo) *appservice.In return nil } } - return portal.user.GetPuppetByJID(info.SenderJid).Intent() + return portal.bridge.GetPuppetByJID(info.SenderJid).Intent() } func (portal *Portal) SetReply(content *gomatrix.Content, info whatsapp.MessageInfo) { if len(info.QuotedMessageID) == 0 { return } - message := portal.bridge.DB.Message.GetByJID(portal.Owner, info.QuotedMessageID) + message := portal.bridge.DB.Message.GetByJID(portal.Key, info.QuotedMessageID) if message != nil { event, err := portal.MainIntent().GetEvent(portal.MXID, message.MXID) if err != nil { @@ -421,29 +421,12 @@ func (portal *Portal) SetReply(content *gomatrix.Content, info whatsapp.MessageI return } -func (portal *Portal) FormatWhatsAppMessage(content *gomatrix.Content) { - output := html.EscapeString(content.Body) - for regex, replacement := range portal.user.waReplString { - output = regex.ReplaceAllString(output, replacement) - } - for regex, replacer := range portal.user.waReplFunc { - output = regex.ReplaceAllStringFunc(output, replacer) - } - if output != content.Body { - content.FormattedBody = output - content.Format = gomatrix.FormatHTML - for regex, replacer := range portal.user.waReplFuncText { - content.Body = regex.ReplaceAllStringFunc(content.Body, replacer) - } - } -} - -func (portal *Portal) HandleTextMessage(message whatsapp.TextMessage) { +func (portal *Portal) HandleTextMessage(source *User, message whatsapp.TextMessage) { if portal.IsDuplicate(message.Info.Id) { return } - err := portal.CreateMatrixRoom() + err := portal.CreateMatrixRoom([]string{source.MXID}) if err != nil { portal.log.Errorln("Failed to create portal room:", err) return @@ -459,7 +442,7 @@ func (portal *Portal) HandleTextMessage(message whatsapp.TextMessage) { MsgType: gomatrix.MsgText, } - portal.FormatWhatsAppMessage(content) + portal.bridge.Formatter.ParseWhatsApp(content) portal.SetReply(content, message.Info) intent.UserTyping(portal.MXID, false, 0) @@ -472,12 +455,12 @@ func (portal *Portal) HandleTextMessage(message whatsapp.TextMessage) { portal.log.Debugln("Handled message", message.Info.Id, "->", resp.EventID) } -func (portal *Portal) HandleMediaMessage(download func() ([]byte, error), thumbnail []byte, info whatsapp.MessageInfo, mimeType, caption string) { +func (portal *Portal) HandleMediaMessage(source *User, download func() ([]byte, error), thumbnail []byte, info whatsapp.MessageInfo, mimeType, caption string) { if portal.IsDuplicate(info.Id) { return } - err := portal.CreateMatrixRoom() + err := portal.CreateMatrixRoom([]string{source.MXID}) if err != nil { portal.log.Errorln("Failed to create portal room:", err) return @@ -559,7 +542,7 @@ func (portal *Portal) HandleMediaMessage(download func() ([]byte, error), thumbn MsgType: gomatrix.MsgNotice, } - portal.FormatWhatsAppMessage(captionContent) + portal.bridge.Formatter.ParseWhatsApp(captionContent) _, err := intent.SendMassagedMessageEvent(portal.MXID, gomatrix.EventMessage, captionContent, ts) if err != nil { @@ -612,7 +595,7 @@ func (portal *Portal) downloadThumbnail(evt *gomatrix.Event) []byte { return buf.Bytes() } -func (portal *Portal) preprocessMatrixMedia(evt *gomatrix.Event, mediaType whatsapp.MediaType) *MediaUpload { +func (portal *Portal) preprocessMatrixMedia(sender *User, evt *gomatrix.Event, mediaType whatsapp.MediaType) *MediaUpload { if evt.Content.Info == nil { evt.Content.Info = &gomatrix.FileInfo{} } @@ -630,7 +613,7 @@ func (portal *Portal) preprocessMatrixMedia(evt *gomatrix.Event, mediaType whats return nil } - url, mediaKey, fileEncSHA256, fileSHA256, fileLength, err := portal.user.Conn.Upload(bytes.NewReader(content), mediaType) + url, mediaKey, fileEncSHA256, fileSHA256, fileLength, err := sender.Conn.Upload(bytes.NewReader(content), mediaType) if err != nil { portal.log.Errorfln("Failed to upload media in %s: %v", evt.ID, err) return nil @@ -657,8 +640,8 @@ type MediaUpload struct { Thumbnail []byte } -func (portal *Portal) GetMessage(jid types.WhatsAppMessageID) *waProto.WebMessageInfo { - node, err := portal.user.Conn.LoadMessagesBefore(portal.JID, jid, 1) +func (portal *Portal) GetMessage(user *User, jid types.WhatsAppMessageID) *waProto.WebMessageInfo { + node, err := user.Conn.LoadMessagesBefore(portal.Key.JID, jid, 1) if err != nil { return nil } @@ -670,7 +653,7 @@ func (portal *Portal) GetMessage(jid types.WhatsAppMessageID) *waProto.WebMessag if !ok { return nil } - node, err = portal.user.Conn.LoadMessagesAfter(portal.JID, msg.GetKey().GetId(), 1) + node, err = user.Conn.LoadMessagesAfter(portal.Key.JID, msg.GetKey().GetId(), 1) if err != nil { return nil } @@ -682,7 +665,11 @@ func (portal *Portal) GetMessage(jid types.WhatsAppMessageID) *waProto.WebMessag return msg } -func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) { +func (portal *Portal) HandleMatrixMessage(sender *User, evt *gomatrix.Event) { + if portal.IsPrivateChat() && sender.JID != portal.Key.Receiver { + return + } + ts := uint64(evt.Timestamp / 1000) status := waProto.WebMessageInfo_ERROR fromMe := true @@ -690,7 +677,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) { Key: &waProto.MessageKey{ FromMe: &fromMe, Id: makeMessageID(), - RemoteJid: &portal.JID, + RemoteJid: &portal.Key.JID, }, MessageTimestamp: &ts, Message: &waProto.Message{}, @@ -702,12 +689,12 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) { evt.Content.RemoveReplyFallback() msg := portal.bridge.DB.Message.GetByMXID(replyToID) if msg != nil { - origMsg := portal.GetMessage(msg.JID) + origMsg := portal.GetMessage(sender, msg.JID) if origMsg != nil { ctxInfo.StanzaId = &msg.JID replyMsgSender := origMsg.GetParticipant() if origMsg.GetKey().GetFromMe() { - replyMsgSender = portal.user.JID() + replyMsgSender = sender.JID } ctxInfo.Participant = &replyMsgSender ctxInfo.QuotedMessage = []*waProto.Message{origMsg.Message} @@ -719,7 +706,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) { case gomatrix.MsgText, gomatrix.MsgEmote: text := evt.Content.Body if evt.Content.Format == gomatrix.FormatHTML { - text = portal.user.htmlParser.Parse(evt.Content.FormattedBody) + text = portal.bridge.Formatter.ParseMatrix(evt.Content.FormattedBody) } if evt.Content.MsgType == gomatrix.MsgEmote { text = "/me " + text @@ -737,7 +724,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) { info.Message.Conversation = &text } case gomatrix.MsgImage: - media := portal.preprocessMatrixMedia(evt, whatsapp.MediaImage) + media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaImage) if media == nil { return } @@ -752,7 +739,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) { FileLength: &media.FileLength, } case gomatrix.MsgVideo: - media := portal.preprocessMatrixMedia(evt, whatsapp.MediaVideo) + media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaVideo) if media == nil { return } @@ -769,7 +756,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) { FileLength: &media.FileLength, } case gomatrix.MsgAudio: - media := portal.preprocessMatrixMedia(evt, whatsapp.MediaAudio) + media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaAudio) if media == nil { return } @@ -784,7 +771,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) { FileLength: &media.FileLength, } case gomatrix.MsgFile: - media := portal.preprocessMatrixMedia(evt, whatsapp.MediaDocument) + media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaDocument) if media == nil { return } @@ -800,7 +787,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) { portal.log.Debugln("Unhandled Matrix event:", evt) return } - err = portal.user.Conn.Send(info) + err = sender.Conn.Send(info) portal.MarkHandled(info.GetKey().GetId(), evt.ID) if err != nil { portal.log.Errorfln("Error handling Matrix event %s: %v", evt.ID, err) diff --git a/puppet.go b/puppet.go index 5b58511..587cabc 100644 --- a/puppet.go +++ b/puppet.go @@ -30,105 +30,83 @@ import ( "maunium.net/go/mautrix-whatsapp/whatsapp-ext" ) -func (bridge *Bridge) ParsePuppetMXID(mxid types.MatrixUserID) (types.MatrixUserID, types.WhatsAppID, bool) { +func (bridge *Bridge) ParsePuppetMXID(mxid types.MatrixUserID) (types.WhatsAppID, bool) { userIDRegex, err := regexp.Compile(fmt.Sprintf("^@%s:%s$", - bridge.Config.Bridge.FormatUsername("(.+)", "([0-9]+)"), + bridge.Config.Bridge.FormatUsername("([0-9]+)"), bridge.Config.Homeserver.Domain)) if err != nil { bridge.Log.Warnln("Failed to compile puppet user ID regex:", err) - return "", "", false + return "", false } match := userIDRegex.FindStringSubmatch(string(mxid)) - if match == nil || len(match) != 3 { - return "", "", false + if match == nil || len(match) != 2 { + return "", false } - receiver := types.MatrixUserID(match[1]) - receiver = strings.Replace(receiver, "=40", "@", 1) - colonIndex := strings.LastIndex(receiver, "=3") - receiver = receiver[:colonIndex] + ":" + receiver[colonIndex+len("=3"):] jid := types.WhatsAppID(match[2] + whatsappExt.NewUserSuffix) - return receiver, jid, true + return jid, true } func (bridge *Bridge) GetPuppetByMXID(mxid types.MatrixUserID) *Puppet { - receiver, jid, ok := bridge.ParsePuppetMXID(mxid) + jid, ok := bridge.ParsePuppetMXID(mxid) if !ok { return nil } - user := bridge.GetUser(receiver) - if user == nil { - return nil - } - - return user.GetPuppetByJID(jid) + return bridge.GetPuppetByJID(jid) } -func (user *User) GetPuppetByMXID(mxid types.MatrixUserID) *Puppet { - receiver, jid, ok := user.bridge.ParsePuppetMXID(mxid) - if !ok || receiver != user.ID { - return nil - } - - return user.GetPuppetByJID(jid) -} - -func (user *User) GetPuppetByJID(jid types.WhatsAppID) *Puppet { - user.puppetsLock.Lock() - defer user.puppetsLock.Unlock() - puppet, ok := user.puppets[jid] +func (bridge *Bridge) GetPuppetByJID(jid types.WhatsAppID) *Puppet { + bridge.puppetsLock.Lock() + defer bridge.puppetsLock.Unlock() + puppet, ok := bridge.puppets[jid] if !ok { - dbPuppet := user.bridge.DB.Puppet.Get(jid, user.ID) + dbPuppet := bridge.DB.Puppet.Get(jid) if dbPuppet == nil { - dbPuppet = user.bridge.DB.Puppet.New() + dbPuppet = bridge.DB.Puppet.New() dbPuppet.JID = jid - dbPuppet.Receiver = user.ID dbPuppet.Insert() } - puppet = user.NewPuppet(dbPuppet) - user.puppets[puppet.JID] = puppet + puppet = bridge.NewPuppet(dbPuppet) + bridge.puppets[puppet.JID] = puppet } return puppet } -func (user *User) GetAllPuppets() []*Puppet { - user.puppetsLock.Lock() - defer user.puppetsLock.Unlock() - dbPuppets := user.bridge.DB.Puppet.GetAll(user.ID) +func (bridge *Bridge) GetAllPuppets() []*Puppet { + bridge.puppetsLock.Lock() + defer bridge.puppetsLock.Unlock() + dbPuppets := bridge.DB.Puppet.GetAll() output := make([]*Puppet, len(dbPuppets)) for index, dbPuppet := range dbPuppets { - puppet, ok := user.puppets[dbPuppet.JID] + puppet, ok := bridge.puppets[dbPuppet.JID] if !ok { - puppet = user.NewPuppet(dbPuppet) - user.puppets[dbPuppet.JID] = puppet + puppet = bridge.NewPuppet(dbPuppet) + bridge.puppets[dbPuppet.JID] = puppet } output[index] = puppet } return output } -func (user *User) NewPuppet(dbPuppet *database.Puppet) *Puppet { +func (bridge *Bridge) NewPuppet(dbPuppet *database.Puppet) *Puppet { return &Puppet{ Puppet: dbPuppet, - user: user, - bridge: user.bridge, - log: user.log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)), + bridge: bridge, + log: bridge.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)), MXID: fmt.Sprintf("@%s:%s", - user.bridge.Config.Bridge.FormatUsername( - dbPuppet.Receiver, + bridge.Config.Bridge.FormatUsername( strings.Replace( dbPuppet.JID, whatsappExt.NewUserSuffix, "", 1)), - user.bridge.Config.Homeserver.Domain), + bridge.Config.Homeserver.Domain), } } type Puppet struct { *database.Puppet - user *User bridge *Bridge log log.Logger @@ -143,13 +121,13 @@ func (puppet *Puppet) PhoneNumber() string { } func (puppet *Puppet) Intent() *appservice.IntentAPI { - return puppet.bridge.AppService.Intent(puppet.MXID) + return puppet.bridge.AS.Intent(puppet.MXID) } -func (puppet *Puppet) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool { +func (puppet *Puppet) UpdateAvatar(source *User, avatar *whatsappExt.ProfilePicInfo) bool { if avatar == nil { var err error - avatar, err = puppet.user.Conn.GetProfilePicThumb(puppet.JID) + avatar, err = source.Conn.GetProfilePicThumb(puppet.JID) if err != nil { puppet.log.Errorln(err) return false @@ -184,11 +162,11 @@ func (puppet *Puppet) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool { return true } -func (puppet *Puppet) Sync(contact whatsapp.Contact) { +func (puppet *Puppet) Sync(source *User, contact whatsapp.Contact) { puppet.Intent().EnsureRegistered() - if contact.Jid == puppet.user.JID() { - contact.Notify = puppet.user.Conn.Info.Pushname + if contact.Jid == source.JID { + contact.Notify = source.Conn.Info.Pushname } newName := puppet.bridge.Config.Bridge.FormatDisplayname(contact) if puppet.Displayname != newName { @@ -201,7 +179,7 @@ func (puppet *Puppet) Sync(contact whatsapp.Contact) { } } - if puppet.UpdateAvatar(nil) { + if puppet.UpdateAvatar(source, nil) { puppet.Update() } } diff --git a/user.go b/user.go index 385e8ba..2155752 100644 --- a/user.go +++ b/user.go @@ -17,14 +17,11 @@ package main import ( - "regexp" "strings" - "sync" "time" "github.com/Rhymen/go-whatsapp" "github.com/skip2/go-qrcode" - "maunium.net/go/gomatrix/format" log "maunium.net/go/maulogger" "maunium.net/go/mautrix-whatsapp/database" "maunium.net/go/mautrix-whatsapp/types" @@ -41,31 +38,42 @@ type User struct { Admin bool Whitelisted bool jid string - - portalsByMXID map[types.MatrixRoomID]*Portal - portalsByJID map[types.WhatsAppID]*Portal - portalsLock sync.Mutex - puppets map[types.WhatsAppID]*Puppet - puppetsLock sync.Mutex - - htmlParser *format.HTMLParser - - waReplString map[*regexp.Regexp]string - waReplFunc map[*regexp.Regexp]func(string) string - waReplFuncText map[*regexp.Regexp]func(string) string } -func (bridge *Bridge) GetUser(userID types.MatrixUserID) *User { - user, ok := bridge.users[userID] +func (bridge *Bridge) GetUserByMXID(userID types.MatrixUserID) *User { + bridge.usersLock.Lock() + defer bridge.usersLock.Unlock() + user, ok := bridge.usersByMXID[userID] if !ok { - dbUser := bridge.DB.User.Get(userID) + dbUser := bridge.DB.User.GetByMXID(userID) if dbUser == nil { dbUser = bridge.DB.User.New() - dbUser.ID = userID + dbUser.MXID = userID dbUser.Insert() } user = bridge.NewUser(dbUser) - bridge.users[user.ID] = user + bridge.usersByMXID[user.MXID] = user + if len(user.ManagementRoom) > 0 { + bridge.managementRooms[user.ManagementRoom] = user + } + } + return user +} + + +func (bridge *Bridge) GetUserByJID(userID types.WhatsAppID) *User { + bridge.usersLock.Lock() + defer bridge.usersLock.Unlock() + user, ok := bridge.usersByJID[userID] + if !ok { + dbUser := bridge.DB.User.GetByMXID(userID) + if dbUser == nil { + dbUser = bridge.DB.User.New() + dbUser.MXID = userID + dbUser.Insert() + } + user = bridge.NewUser(dbUser) + bridge.usersByJID[user.JID] = user if len(user.ManagementRoom) > 0 { bridge.managementRooms[user.ManagementRoom] = user } @@ -74,13 +82,15 @@ func (bridge *Bridge) GetUser(userID types.MatrixUserID) *User { } func (bridge *Bridge) GetAllUsers() []*User { + bridge.usersLock.Lock() + defer bridge.usersLock.Unlock() dbUsers := bridge.DB.User.GetAll() output := make([]*User, len(dbUsers)) for index, dbUser := range dbUsers { - user, ok := bridge.users[dbUser.ID] + user, ok := bridge.usersByMXID[dbUser.MXID] if !ok { user = bridge.NewUser(dbUser) - bridge.users[user.ID] = user + bridge.usersByMXID[user.MXID] = user if len(user.ManagementRoom) > 0 { bridge.managementRooms[user.ManagementRoom] = user } @@ -94,15 +104,10 @@ func (bridge *Bridge) NewUser(dbUser *database.User) *User { user := &User{ User: dbUser, bridge: bridge, - log: bridge.Log.Sub("User").Sub(string(dbUser.ID)), - portalsByMXID: make(map[types.MatrixRoomID]*Portal), - portalsByJID: make(map[types.WhatsAppID]*Portal), - puppets: make(map[types.WhatsAppID]*Puppet), + log: bridge.Log.Sub("User").Sub(string(dbUser.MXID)), } - user.Whitelisted = user.bridge.Config.Bridge.Permissions.IsWhitelisted(user.ID) - user.Admin = user.bridge.Config.Bridge.Permissions.IsAdmin(user.ID) - user.htmlParser = user.newHTMLParser() - user.waReplString, user.waReplFunc, user.waReplFuncText = user.newWhatsAppFormatMaps() + user.Whitelisted = user.bridge.Config.Bridge.Permissions.IsWhitelisted(user.MXID) + user.Admin = user.bridge.Config.Bridge.Permissions.IsAdmin(user.MXID) return user } @@ -152,7 +157,6 @@ func (user *User) RestoreSession() bool { sess, err := user.Conn.RestoreSession(*user.Session) if err != nil { user.log.Errorln("Failed to restore session:", err) - //user.SetSession(nil) return false } user.SetSession(&sess) @@ -162,8 +166,12 @@ func (user *User) RestoreSession() bool { return false } +func (user *User) IsLoggedIn() bool { + return user.Conn != nil +} + func (user *User) Login(roomID types.MatrixRoomID) { - bot := user.bridge.AppService.BotClient() + bot := user.bridge.AS.BotClient() qrChan := make(chan string, 2) go func() { @@ -194,38 +202,24 @@ func (user *User) Login(roomID types.MatrixRoomID) { qrChan <- "error" return } + user.JID = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1) user.Session = &session user.Update() bot.SendNotice(roomID, "Successfully logged in. Synchronizing chats...") go user.Sync() } -func (user *User) JID() string { - if user.Conn == nil { - return "" - } - if len(user.jid) == 0 { - user.jid = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1) - } - return user.jid -} - func (user *User) Sync() { user.log.Debugln("Syncing...") user.Conn.Contacts() for jid, contact := range user.Conn.Store.Contacts { if strings.HasSuffix(jid, whatsappExt.NewUserSuffix) { - puppet := user.GetPuppetByJID(contact.Jid) - puppet.Sync(contact) + puppet := user.bridge.GetPuppetByJID(contact.Jid) + puppet.Sync(user, contact) + } else { + portal := user.bridge.GetPortalByJID(database.GroupPortalKey(contact.Jid)) + portal.Sync(user, contact) } - - if len(contact.Notify) == 0 && !strings.HasSuffix(jid, "@g.us") { - // No messages sent -> don't bridge - continue - } - - portal := user.GetPortalByJID(contact.Jid) - portal.Sync(contact) } } @@ -237,33 +231,41 @@ func (user *User) HandleJSONParseError(err error) { user.log.Errorln("WhatsApp JSON parse error:", err) } +func (user *User) PortalKey(jid types.WhatsAppID) database.PortalKey { + return database.NewPortalKey(jid, user.JID) +} + +func (user *User) GetPortalByJID(jid types.WhatsAppID) *Portal { + return user.bridge.GetPortalByJID(user.PortalKey(jid)) +} + func (user *User) HandleTextMessage(message whatsapp.TextMessage) { portal := user.GetPortalByJID(message.Info.RemoteJid) - portal.HandleTextMessage(message) + portal.HandleTextMessage(user, message) } func (user *User) HandleImageMessage(message whatsapp.ImageMessage) { portal := user.GetPortalByJID(message.Info.RemoteJid) - portal.HandleMediaMessage(message.Download, message.Thumbnail, message.Info, message.Type, message.Caption) + portal.HandleMediaMessage(user, message.Download, message.Thumbnail, message.Info, message.Type, message.Caption) } func (user *User) HandleVideoMessage(message whatsapp.VideoMessage) { portal := user.GetPortalByJID(message.Info.RemoteJid) - portal.HandleMediaMessage(message.Download, message.Thumbnail, message.Info, message.Type, message.Caption) + portal.HandleMediaMessage(user, message.Download, message.Thumbnail, message.Info, message.Type, message.Caption) } func (user *User) HandleAudioMessage(message whatsapp.AudioMessage) { portal := user.GetPortalByJID(message.Info.RemoteJid) - portal.HandleMediaMessage(message.Download, nil, message.Info, message.Type, "") + portal.HandleMediaMessage(user, message.Download, nil, message.Info, message.Type, "") } func (user *User) HandleDocumentMessage(message whatsapp.DocumentMessage) { portal := user.GetPortalByJID(message.Info.RemoteJid) - portal.HandleMediaMessage(message.Download, message.Thumbnail, message.Info, message.Type, message.Title) + portal.HandleMediaMessage(user, message.Download, message.Thumbnail, message.Info, message.Type, message.Title) } func (user *User) HandlePresence(info whatsappExt.Presence) { - puppet := user.GetPuppetByJID(info.SenderJID) + puppet := user.bridge.GetPuppetByJID(info.SenderJID) switch info.Status { case whatsappExt.PresenceUnavailable: puppet.Intent().SetPresence("offline") @@ -277,6 +279,12 @@ func (user *User) HandlePresence(info whatsappExt.Presence) { } case whatsappExt.PresenceComposing: portal := user.GetPortalByJID(info.JID) + if len(puppet.typingIn) > 0 && puppet.typingAt+15 > time.Now().Unix() { + if puppet.typingIn == portal.MXID { + return + } + puppet.Intent().UserTyping(puppet.typingIn, false, 0) + } puppet.typingIn = portal.MXID puppet.typingAt = time.Now().Unix() puppet.Intent().UserTyping(portal.MXID, true, 15*1000) @@ -290,9 +298,9 @@ func (user *User) HandleMsgInfo(info whatsappExt.MsgInfo) { return } - intent := user.GetPuppetByJID(info.SenderJID).Intent() + intent := user.bridge.GetPuppetByJID(info.SenderJID).Intent() for _, id := range info.IDs { - msg := user.bridge.DB.Message.GetByJID(user.ID, id) + msg := user.bridge.DB.Message.GetByJID(portal.Key, id) if msg == nil { continue } @@ -308,11 +316,11 @@ func (user *User) HandleCommand(cmd whatsappExt.Command) { switch cmd.Type { case whatsappExt.CommandPicture: if strings.HasSuffix(cmd.JID, whatsappExt.NewUserSuffix) { - puppet := user.GetPuppetByJID(cmd.JID) - puppet.UpdateAvatar(cmd.ProfilePicInfo) + puppet := user.bridge.GetPuppetByJID(cmd.JID) + puppet.UpdateAvatar(user, cmd.ProfilePicInfo) } else { portal := user.GetPortalByJID(cmd.JID) - portal.UpdateAvatar(cmd.ProfilePicInfo) + portal.UpdateAvatar(user, cmd.ProfilePicInfo) } } } diff --git a/vendor/maunium.net/go/gomatrix/client.go b/vendor/maunium.net/go/gomatrix/client.go index 14549ba..0806138 100644 --- a/vendor/maunium.net/go/gomatrix/client.go +++ b/vendor/maunium.net/go/gomatrix/client.go @@ -463,7 +463,7 @@ func (cli *Client) SetAvatarURL(url string) (err error) { // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. func (cli *Client) SendMessageEvent(roomID string, eventType EventType, contentJSON interface{}) (resp *RespSendEvent, err error) { txnID := txnID() - urlPath := cli.BuildURL("rooms", roomID, "send", string(eventType), txnID) + urlPath := cli.BuildURL("rooms", roomID, "send", eventType.String(), txnID) _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) return } @@ -472,7 +472,7 @@ func (cli *Client) SendMessageEvent(roomID string, eventType EventType, contentJ // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. func (cli *Client) SendMassagedMessageEvent(roomID string, eventType EventType, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) { txnID := txnID() - urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "send", string(eventType), txnID}, map[string]string{ + urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "send", eventType.String(), txnID}, map[string]string{ "ts": strconv.FormatInt(ts, 10), }) _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) @@ -482,7 +482,7 @@ func (cli *Client) SendMassagedMessageEvent(roomID string, eventType EventType, // SendStateEvent sends a state event into a room. See http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-state-eventtype-statekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. func (cli *Client) SendStateEvent(roomID string, eventType EventType, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) { - urlPath := cli.BuildURL("rooms", roomID, "state", string(eventType), stateKey) + urlPath := cli.BuildURL("rooms", roomID, "state", eventType.String(), stateKey) _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) return } @@ -490,7 +490,7 @@ func (cli *Client) SendStateEvent(roomID string, eventType EventType, stateKey s // SendStateEvent sends a state event into a room. See http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-state-eventtype-statekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. func (cli *Client) SendMassagedStateEvent(roomID string, eventType EventType, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) { - urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "state", string(eventType), stateKey}, map[string]string{ + urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ "ts": strconv.FormatInt(ts, 10), }) _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) @@ -500,7 +500,7 @@ func (cli *Client) SendMassagedStateEvent(roomID string, eventType EventType, st // SendText sends an m.room.message event into the given room with a msgtype of m.text // See http://matrix.org/docs/spec/client_server/r0.2.0.html#m-text func (cli *Client) SendText(roomID, text string) (*RespSendEvent, error) { - return cli.SendMessageEvent(roomID, "m.room.message", Content{ + return cli.SendMessageEvent(roomID, EventMessage, Content{ MsgType: MsgText, Body: text, }) @@ -509,7 +509,7 @@ func (cli *Client) SendText(roomID, text string) (*RespSendEvent, error) { // SendImage sends an m.room.message event into the given room with a msgtype of m.image // See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-image func (cli *Client) SendImage(roomID, body, url string) (*RespSendEvent, error) { - return cli.SendMessageEvent(roomID, "m.room.message", Content{ + return cli.SendMessageEvent(roomID, EventMessage, Content{ MsgType: MsgImage, Body: body, URL: url, @@ -519,7 +519,7 @@ func (cli *Client) SendImage(roomID, body, url string) (*RespSendEvent, error) { // SendVideo sends an m.room.message event into the given room with a msgtype of m.video // See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-video func (cli *Client) SendVideo(roomID, body, url string) (*RespSendEvent, error) { - return cli.SendMessageEvent(roomID, "m.room.message", Content{ + return cli.SendMessageEvent(roomID, EventMessage, Content{ MsgType: MsgVideo, Body: body, URL: url, @@ -529,7 +529,7 @@ func (cli *Client) SendVideo(roomID, body, url string) (*RespSendEvent, error) { // SendNotice sends an m.room.message event into the given room with a msgtype of m.notice // See http://matrix.org/docs/spec/client_server/r0.2.0.html#m-notice func (cli *Client) SendNotice(roomID, text string) (*RespSendEvent, error) { - return cli.SendMessageEvent(roomID, "m.room.message", Content{ + return cli.SendMessageEvent(roomID, EventMessage, Content{ MsgType: MsgNotice, Body: text, }) @@ -622,7 +622,7 @@ func (cli *Client) SetPresence(status string) (err error) { // the HTTP response body, or return an error. // See http://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-rooms-roomid-state-eventtype-statekey func (cli *Client) StateEvent(roomID string, eventType EventType, stateKey string, outContent interface{}) (err error) { - u := cli.BuildURL("rooms", roomID, "state", string(eventType), stateKey) + u := cli.BuildURL("rooms", roomID, "state", eventType.String(), stateKey) _, err = cli.MakeRequest("GET", u, nil, outContent) return } diff --git a/vendor/maunium.net/go/gomatrix/events.go b/vendor/maunium.net/go/gomatrix/events.go index fcd0538..9b4ef3e 100644 --- a/vendor/maunium.net/go/gomatrix/events.go +++ b/vendor/maunium.net/go/gomatrix/events.go @@ -5,28 +5,44 @@ import ( "sync" ) -type EventType string +type EventType struct { + Type string + IsState bool +} + +func (et *EventType) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &et.Type) +} + +func (et *EventType) MarshalJSON() ([]byte, error) { + return json.Marshal(&et.Type) +} + +func (et *EventType) String() string { + return et.Type +} + type MessageType string // State events -const ( - StateAliases EventType = "m.room.aliases" - StateCanonicalAlias = "m.room.canonical_alias" - StateCreate = "m.room.create" - StateJoinRules = "m.room.join_rules" - StateMember = "m.room.member" - StatePowerLevels = "m.room.power_levels" - StateRoomName = "m.room.name" - StateTopic = "m.room.topic" - StateRoomAvatar = "m.room.avatar" - StatePinnedEvents = "m.room.pinned_events" +var ( + StateAliases = EventType{"m.room.aliases", true} + StateCanonicalAlias = EventType{"m.room.canonical_alias", true} + StateCreate = EventType{"m.room.create", true} + StateJoinRules = EventType{"m.room.join_rules", true} + StateMember = EventType{"m.room.member", true} + StatePowerLevels = EventType{"m.room.power_levels", true} + StateRoomName = EventType{"m.room.name", true} + StateTopic = EventType{"m.room.topic", true} + StateRoomAvatar = EventType{"m.room.avatar", true} + StatePinnedEvents = EventType{"m.room.pinned_events", true} ) // Message events -const ( - EventRedaction EventType = "m.room.redaction" - EventMessage = "m.room.message" - EventSticker = "m.sticker" +var ( + EventRedaction = EventType{"m.room.redaction", false} + EventMessage = EventType{"m.room.message", false} + EventSticker = EventType{"m.sticker", false} ) // Msgtypes @@ -258,12 +274,12 @@ func (pl *PowerLevels) EnsureUserLevel(userID string, level int) bool { return false } -func (pl *PowerLevels) GetEventLevel(eventType EventType, isState bool) int { +func (pl *PowerLevels) GetEventLevel(eventType EventType) int { pl.eventsLock.RLock() defer pl.eventsLock.RUnlock() level, ok := pl.Events[eventType] if !ok { - if isState { + if eventType.IsState { return pl.StateDefault() } return pl.EventsDefault @@ -271,20 +287,20 @@ func (pl *PowerLevels) GetEventLevel(eventType EventType, isState bool) int { return level } -func (pl *PowerLevels) SetEventLevel(eventType EventType, isState bool, level int) { +func (pl *PowerLevels) SetEventLevel(eventType EventType, level int) { pl.eventsLock.Lock() defer pl.eventsLock.Unlock() - if (isState && level == pl.StateDefault()) || (!isState && level == pl.EventsDefault) { + if (eventType.IsState && level == pl.StateDefault()) || (!eventType.IsState && level == pl.EventsDefault) { delete(pl.Events, eventType) } else { pl.Events[eventType] = level } } -func (pl *PowerLevels) EnsureEventLevel(eventType EventType, isState bool, level int) bool { - existingLevel := pl.GetEventLevel(eventType, isState) +func (pl *PowerLevels) EnsureEventLevel(eventType EventType, level int) bool { + existingLevel := pl.GetEventLevel(eventType) if existingLevel != level { - pl.SetEventLevel(eventType, isState, level) + pl.SetEventLevel(eventType, level) return true } return false diff --git a/vendor/maunium.net/go/mautrix-appservice/appservice.go b/vendor/maunium.net/go/mautrix-appservice/appservice.go index 1ae68a2..e089b32 100644 --- a/vendor/maunium.net/go/mautrix-appservice/appservice.go +++ b/vendor/maunium.net/go/mautrix-appservice/appservice.go @@ -2,17 +2,19 @@ package appservice import ( "fmt" + "html/template" "io/ioutil" "os" + "path/filepath" "gopkg.in/yaml.v2" - "maunium.net/go/maulogger" - "strings" - "net/http" "errors" "maunium.net/go/gomatrix" + "maunium.net/go/maulogger" + "net/http" "regexp" + "strings" ) // EventChannelSize is the size for the Events channel in Appservice instances. @@ -263,15 +265,24 @@ func CreateLogConfig() LogConfig { } } +type FileFormatData struct { + Date string + Index int +} + // GetFileFormat returns a mauLogger-compatible logger file format based on the data in the struct. func (lc LogConfig) GetFileFormat() maulogger.LoggerFileFormat { - path := lc.FileNameFormat - if len(lc.Directory) > 0 { - path = lc.Directory + "/" + path - } + os.MkdirAll(lc.Directory, 0700) + path := filepath.Join(lc.Directory, lc.FileNameFormat) + tpl, _ := template.New("fileformat").Parse(path) return func(now string, i int) string { - return fmt.Sprintf(path, now, i) + var buf strings.Builder + tpl.Execute(&buf, FileFormatData{ + Date: now, + Index: i, + }) + return buf.String() } } diff --git a/vendor/maunium.net/go/mautrix-appservice/intent.go b/vendor/maunium.net/go/mautrix-appservice/intent.go index 957d529..3aa864b 100644 --- a/vendor/maunium.net/go/mautrix-appservice/intent.go +++ b/vendor/maunium.net/go/mautrix-appservice/intent.go @@ -201,19 +201,19 @@ func (intent *IntentAPI) RedactEvent(roomID, eventID string, req *gomatrix.ReqRe } func (intent *IntentAPI) SetRoomName(roomID, roomName string) (*gomatrix.RespSendEvent, error) { - return intent.SendStateEvent(roomID, "m.room.name", "", map[string]interface{}{ + return intent.SendStateEvent(roomID, gomatrix.StateRoomName, "", map[string]interface{}{ "name": roomName, }) } func (intent *IntentAPI) SetRoomAvatar(roomID, avatarURL string) (*gomatrix.RespSendEvent, error) { - return intent.SendStateEvent(roomID, "m.room.avatar", "", map[string]interface{}{ + return intent.SendStateEvent(roomID, gomatrix.StateRoomAvatar, "", map[string]interface{}{ "url": avatarURL, }) } func (intent *IntentAPI) SetRoomTopic(roomID, topic string) (*gomatrix.RespSendEvent, error) { - return intent.SendStateEvent(roomID, "m.room.topic", "", map[string]interface{}{ + return intent.SendStateEvent(roomID, gomatrix.StateTopic, "", map[string]interface{}{ "topic": topic, }) } diff --git a/vendor/maunium.net/go/mautrix-appservice/statestore.go b/vendor/maunium.net/go/mautrix-appservice/statestore.go index 944e865..14f9b74 100644 --- a/vendor/maunium.net/go/mautrix-appservice/statestore.go +++ b/vendor/maunium.net/go/mautrix-appservice/statestore.go @@ -15,13 +15,15 @@ type StateStore interface { SetTyping(roomID, userID string, timeout int64) IsInRoom(roomID, userID string) bool + IsInvited(roomID, userID string) bool + IsMembership(roomID, userID string, allowedMemberships ...string) bool SetMembership(roomID, userID, membership string) SetPowerLevels(roomID string, levels *gomatrix.PowerLevels) GetPowerLevels(roomID string) *gomatrix.PowerLevels GetPowerLevel(roomID, userID string) int - GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType, isState bool) int - HasPowerLevel(roomID, userID string, eventType gomatrix.EventType, isState bool) bool + GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType) int + HasPowerLevel(roomID, userID string, eventType gomatrix.EventType) bool } func (as *AppService) UpdateState(evt *gomatrix.Event) { @@ -126,7 +128,21 @@ func (store *BasicStateStore) GetMembership(roomID, userID string) string { } func (store *BasicStateStore) IsInRoom(roomID, userID string) bool { - return store.GetMembership(roomID, userID) == "join" + return store.IsMembership(roomID, userID, "join") +} + +func (store *BasicStateStore) IsInvited(roomID, userID string) bool { + return store.IsMembership(roomID, userID, "join", "invite") +} + +func (store *BasicStateStore) IsMembership(roomID, userID string, allowedMemberships ...string) bool { + membership := store.GetMembership(roomID, userID) + for _, allowedMembership := range allowedMemberships { + if allowedMembership == membership { + return true + } + } + return false } func (store *BasicStateStore) SetMembership(roomID, userID, membership string) { @@ -160,19 +176,10 @@ func (store *BasicStateStore) GetPowerLevel(roomID, userID string) int { return store.GetPowerLevels(roomID).GetUserLevel(userID) } -func (store *BasicStateStore) GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType, isState bool) int { - levels := store.GetPowerLevels(roomID) - switch eventType { - case "kick": - return levels.Kick() - case "invite": - return levels.Invite() - case "redact": - return levels.Redact() - } - return levels.GetEventLevel(eventType, isState) +func (store *BasicStateStore) GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType) int { + return store.GetPowerLevels(roomID).GetEventLevel(eventType) } -func (store *BasicStateStore) HasPowerLevel(roomID, userID string, eventType gomatrix.EventType, isState bool) bool { - return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType, isState) +func (store *BasicStateStore) HasPowerLevel(roomID, userID string, eventType gomatrix.EventType) bool { + return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType) }