Initial desegregation of users and automatic config updating

This commit is contained in:
Tulir Asokan 2018-08-29 00:40:54 +03:00
parent 55c3ab2d4f
commit c7348f29b0
24 changed files with 806 additions and 475 deletions

1
.gitignore vendored
View file

@ -6,3 +6,4 @@
*.session *.session
*.json *.json
*.db *.db
*.log

4
Gopkg.lock generated
View file

@ -123,7 +123,7 @@
".", ".",
"format" "format"
] ]
revision = "ead1f970c8f56d1854cb9eb4a54c03aa6dafd753" revision = "42a3133c4980e4b1ea5fb52329d977f592d67cf0"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -141,7 +141,7 @@
branch = "master" branch = "master"
name = "maunium.net/go/mautrix-appservice" name = "maunium.net/go/mautrix-appservice"
packages = ["."] packages = ["."]
revision = "269f2ab602126a2de94bc86a457392426cce1ab2" revision = "37d4449056015cea5d0a4420bba595c61ad32007"
[solve-meta] [solve-meta]
analyzer-name = "dep" analyzer-name = "dep"

View file

@ -48,7 +48,7 @@ type CommandEvent struct {
func (ce *CommandEvent) Reply(msg string) { func (ce *CommandEvent) Reply(msg string) {
_, err := ce.Bot.SendNotice(string(ce.RoomID), msg) _, err := ce.Bot.SendNotice(string(ce.RoomID), msg)
if err != nil { if err != nil {
ce.Handler.log.Warnfln("Failed to reply to command from %s: %v", ce.User.ID, err) ce.Handler.log.Warnfln("Failed to reply to command from %s: %v", ce.User.MXID, err)
} }
} }
@ -56,7 +56,7 @@ func (handler *CommandHandler) Handle(roomID types.MatrixRoomID, user *User, mes
args := strings.Split(message, " ") args := strings.Split(message, " ")
cmd := strings.ToLower(args[0]) cmd := strings.ToLower(args[0])
ce := &CommandEvent{ ce := &CommandEvent{
Bot: handler.bridge.AppService.BotIntent(), Bot: handler.bridge.AS.BotIntent(),
Bridge: handler.bridge, Bridge: handler.bridge,
Handler: handler, Handler: handler,
RoomID: roomID, RoomID: roomID,

View file

@ -56,12 +56,7 @@ func (bc *BridgeConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
return err return err
} }
type DisplaynameTemplateArgs struct {
Displayname string
}
type UsernameTemplateArgs struct { type UsernameTemplateArgs struct {
Receiver string
UserID string UserID string
} }
@ -74,14 +69,9 @@ func (bc BridgeConfig) FormatDisplayname(contact whatsapp.Contact) string {
return buf.String() return buf.String()
} }
func (bc BridgeConfig) FormatUsername(receiver types.MatrixUserID, userID types.WhatsAppID) string { func (bc BridgeConfig) FormatUsername(userID types.WhatsAppID) string {
var buf bytes.Buffer var buf bytes.Buffer
receiver = strings.Replace(receiver, "@", "=40", 1) bc.usernameTemplate.Execute(&buf, userID)
receiver = strings.Replace(receiver, ":", "=3", 1)
bc.usernameTemplate.Execute(&buf, UsernameTemplateArgs{
Receiver: receiver,
UserID: userID,
})
return buf.String() return buf.String()
} }
@ -92,7 +82,7 @@ func (bc BridgeConfig) MarshalYAML() (interface{}, error) {
Name: "{{.Name}}", Name: "{{.Name}}",
Short: "{{.Short}}", Short: "{{.Short}}",
}) })
bc.UsernameTemplate = bc.FormatUsername("{{.Receiver}}", "{{.UserID}}") bc.UsernameTemplate = bc.FormatUsername("{{.}}")
return bc, nil return bc, nil
} }

View file

@ -78,7 +78,6 @@ func (config *Config) Save(path string) error {
func (config *Config) MakeAppService() (*appservice.AppService, error) { func (config *Config) MakeAppService() (*appservice.AppService, error) {
as := appservice.Create() as := appservice.Create()
as.LogConfig = config.Logging
as.HomeserverDomain = config.Homeserver.Domain as.HomeserverDomain = config.Homeserver.Domain
as.HomeserverURL = config.Homeserver.Address as.HomeserverURL = config.Homeserver.Address
as.Host.Hostname = config.AppService.Hostname as.Host.Hostname = config.AppService.Hostname

115
config/recursivemap.go Normal file
View file

@ -0,0 +1,115 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2018 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"strings"
)
type RecursiveMap map[interface{}]interface{}
func (rm RecursiveMap) GetDefault(path string, defVal interface{}) interface{} {
val, ok := rm.Get(path)
if !ok {
return defVal
}
return val
}
func (rm RecursiveMap) GetMap(path string) RecursiveMap {
val := rm.GetDefault(path, nil)
if val == nil {
return nil
}
newRM, ok := val.(map[interface{}]interface{})
if ok {
return RecursiveMap(newRM)
}
return nil
}
func (rm RecursiveMap) Get(path string) (interface{}, bool) {
if index := strings.IndexRune(path, '.'); index >= 0 {
key := path[:index]
path = path[index+1:]
submap := rm.GetMap(key)
if submap == nil {
return nil, false
}
return submap.Get(path)
}
val, ok := rm[path]
return val, ok
}
func (rm RecursiveMap) GetIntDefault(path string, defVal int) int {
val, ok := rm.GetInt(path)
if !ok {
return defVal
}
return val
}
func (rm RecursiveMap) GetInt(path string) (int, bool) {
val, ok := rm.Get(path)
if !ok {
return 0, ok
}
intVal, ok := val.(int)
return intVal, ok
}
func (rm RecursiveMap) GetStringDefault(path string, defVal string) string {
val, ok := rm.GetString(path)
if !ok {
return defVal
}
return val
}
func (rm RecursiveMap) GetString(path string) (string, bool) {
val, ok := rm.Get(path)
if !ok {
return "", ok
}
strVal, ok := val.(string)
return strVal, ok
}
func (rm RecursiveMap) Set(path string, value interface{}) {
if index := strings.IndexRune(path, '.'); index >= 0 {
key := path[:index]
path = path[index+1:]
nextRM := rm.GetMap(key)
if nextRM == nil {
nextRM = make(RecursiveMap)
rm[key] = nextRM
}
nextRM.Set(path, value)
return
}
rm[path] = value
}
func (rm RecursiveMap) CopyFrom(otherRM RecursiveMap, path string) {
val, ok := otherRM.Get(path)
if ok {
rm.Set(path, val)
}
}

View file

@ -56,7 +56,7 @@ func (config *Config) copyToRegistration(registration *appservice.Registration)
registration.SenderLocalpart = config.AppService.Bot.Username registration.SenderLocalpart = config.AppService.Bot.Username
userIDRegex, err := regexp.Compile(fmt.Sprintf("^@%s:%s$", userIDRegex, err := regexp.Compile(fmt.Sprintf("^@%s:%s$",
config.Bridge.FormatUsername(".+", "[0-9]+"), config.Bridge.FormatUsername("[0-9]+"),
config.Homeserver.Domain)) config.Homeserver.Domain))
if err != nil { if err != nil {
return err return err

100
config/update.go Normal file
View file

@ -0,0 +1,100 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2018 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"io/ioutil"
"gopkg.in/yaml.v2"
)
func Update(path, basePath string) error {
oldCfgData, err := ioutil.ReadFile(path)
if err != nil {
return err
}
oldCfg := make(RecursiveMap)
err = yaml.Unmarshal(oldCfgData, &oldCfg)
if err != nil {
return err
}
baseCfgData, err := ioutil.ReadFile(basePath)
if err != nil {
return err
}
baseCfg := make(RecursiveMap)
err = yaml.Unmarshal(baseCfgData, &baseCfg)
if err != nil {
return err
}
err = runUpdate(oldCfg, baseCfg)
if err != nil {
return err
}
newCfgData, err := yaml.Marshal(&baseCfg)
if err != nil {
return err
}
return ioutil.WriteFile(path, newCfgData, 0600)
}
func runUpdate(oldCfg, newCfg RecursiveMap) error {
cp := func(path string) {
newCfg.CopyFrom(oldCfg, path)
}
cp("homeserver.address")
cp("homeserver.domain")
cp("appservice.address")
cp("appservice.hostname")
cp("appservice.port")
cp("appservice.database.type")
cp("appservice.database.uri")
cp("appservice.state_store_path")
cp("appservice.id")
cp("appservice.bot.username")
cp("appservice.bot.displayname")
cp("appservice.bot.avatar")
cp("appservice.bot.as_token")
cp("appservice.bot.hs_token")
cp("bridge.username_template")
cp("bridge.displayname_template")
cp("bridge.command_prefix")
cp("bridge.permissions")
cp("logging.directory")
cp("logging.file_name_format")
cp("logging.file_date_format")
cp("logging.file_mode")
cp("logging.timestamp_format")
cp("logging.print_level")
return nil
}

View file

@ -30,12 +30,13 @@ type MessageQuery struct {
func (mq *MessageQuery) CreateTable() error { func (mq *MessageQuery) CreateTable() error {
_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message ( _, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
owner VARCHAR(255), chat_jid VARCHAR(25) NOT NULL,
jid VARCHAR(255), chat_receiver VARCHAR(25) NOT NULL,
mxid VARCHAR(255) NOT NULL UNIQUE, jid VARCHAR(255) NOT NULL,
mxid VARCHAR(255) NOT NULL UNIQUE,
PRIMARY KEY (owner, jid), PRIMARY KEY (chat_jid, jid),
FOREIGN KEY (owner) REFERENCES user(mxid) FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
)`) )`)
return err return err
} }
@ -47,8 +48,8 @@ func (mq *MessageQuery) New() *Message {
} }
} }
func (mq *MessageQuery) GetAll(owner types.MatrixUserID) (messages []*Message) { func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
rows, err := mq.db.Query("SELECT * FROM message WHERE owner=?", owner) rows, err := mq.db.Query("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=?", chat.JID, chat.Receiver)
if err != nil || rows == nil { if err != nil || rows == nil {
return nil return nil
} }
@ -59,8 +60,8 @@ func (mq *MessageQuery) GetAll(owner types.MatrixUserID) (messages []*Message) {
return return
} }
func (mq *MessageQuery) GetByJID(owner types.MatrixUserID, jid types.WhatsAppMessageID) *Message { func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.WhatsAppMessageID) *Message {
return mq.get("SELECT * FROM message WHERE owner=? AND jid=?", owner, jid) return mq.get("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=? AND jid=?", chat.JID, chat.Receiver, jid)
} }
func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message { func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message {
@ -79,13 +80,13 @@ type Message struct {
db *Database db *Database
log log.Logger log log.Logger
Owner types.MatrixUserID Chat PortalKey
JID types.WhatsAppMessageID JID types.WhatsAppMessageID
MXID types.MatrixEventID MXID types.MatrixEventID
} }
func (msg *Message) Scan(row Scannable) *Message { func (msg *Message) Scan(row Scannable) *Message {
err := row.Scan(&msg.Owner, &msg.JID, &msg.MXID) err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
msg.log.Errorln("Database scan failed:", err) msg.log.Errorln("Database scan failed:", err)
@ -96,17 +97,17 @@ func (msg *Message) Scan(row Scannable) *Message {
} }
func (msg *Message) Insert() error { func (msg *Message) Insert() error {
_, err := msg.db.Exec("INSERT INTO message VALUES (?, ?, ?)", msg.Owner, msg.JID, msg.MXID) _, err := msg.db.Exec("INSERT INTO message VALUES (?, ?, ?)", msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID)
if err != nil { if err != nil {
msg.log.Warnfln("Failed to update %s->%s: %v", msg.Owner, msg.JID, err) msg.log.Warnfln("Failed to update %s: %v", msg.Chat, msg.JID, err)
} }
return err return err
} }
func (msg *Message) Update() error { func (msg *Message) Update() error {
_, err := msg.db.Exec("UPDATE portal SET mxid=? WHERE owner=? AND jid=?", msg.MXID, msg.Owner, msg.JID) _, err := msg.db.Exec("UPDATE portal SET mxid=? WHERE chat_jid=? AND chat_receiver=? AND jid=?", msg.MXID, msg.Chat.JID, msg.Chat.Receiver, msg.JID)
if err != nil { if err != nil {
msg.log.Warnfln("Failed to update %s->%s: %v", msg.Owner, msg.JID, err) msg.log.Warnfln("Failed to update %s: %v", msg.Chat, msg.JID, err)
} }
return err return err
} }

View file

@ -18,11 +18,41 @@ package database
import ( import (
"database/sql" "database/sql"
"strings"
log "maunium.net/go/maulogger" log "maunium.net/go/maulogger"
"maunium.net/go/mautrix-whatsapp/types" "maunium.net/go/mautrix-whatsapp/types"
) )
type PortalKey struct {
JID types.WhatsAppID
Receiver types.WhatsAppID
}
func GroupPortalKey(jid types.WhatsAppID) PortalKey {
return PortalKey{
JID: jid,
Receiver: jid,
}
}
func NewPortalKey(jid, receiver types.WhatsAppID) PortalKey {
if strings.HasSuffix(jid, "@g.us") {
receiver = jid
}
return PortalKey{
JID: jid,
Receiver: receiver,
}
}
func (key PortalKey) String() string {
if key.Receiver == key.JID {
return key.JID
}
return key.JID + "-" + key.Receiver
}
type PortalQuery struct { type PortalQuery struct {
db *Database db *Database
log log.Logger log log.Logger
@ -30,16 +60,16 @@ type PortalQuery struct {
func (pq *PortalQuery) CreateTable() error { func (pq *PortalQuery) CreateTable() error {
_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal ( _, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal (
jid VARCHAR(255), jid VARCHAR(25),
owner VARCHAR(255), receiver VARCHAR(25),
mxid VARCHAR(255) UNIQUE, mxid VARCHAR(255) UNIQUE,
name VARCHAR(255) NOT NULL, name VARCHAR(255) NOT NULL,
topic VARCHAR(255) NOT NULL, topic VARCHAR(255) NOT NULL,
avatar VARCHAR(255) NOT NULL, avatar VARCHAR(255) NOT NULL,
PRIMARY KEY (jid, owner), PRIMARY KEY (jid, receiver),
FOREIGN KEY (owner) REFERENCES user(mxid) FOREIGN KEY (receiver) REFERENCES user(mxid)
)`) )`)
return err return err
} }
@ -51,8 +81,8 @@ func (pq *PortalQuery) New() *Portal {
} }
} }
func (pq *PortalQuery) GetAll(owner types.MatrixUserID) (portals []*Portal) { func (pq *PortalQuery) GetAll() (portals []*Portal) {
rows, err := pq.db.Query("SELECT * FROM portal WHERE owner=?", owner) rows, err := pq.db.Query("SELECT * FROM portal")
if err != nil || rows == nil { if err != nil || rows == nil {
return nil return nil
} }
@ -63,8 +93,8 @@ func (pq *PortalQuery) GetAll(owner types.MatrixUserID) (portals []*Portal) {
return return
} }
func (pq *PortalQuery) GetByJID(owner types.MatrixUserID, jid types.WhatsAppID) *Portal { func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
return pq.get("SELECT * FROM portal WHERE jid=? AND owner=?", jid, owner) return pq.get("SELECT * FROM portal WHERE jid=? AND receiver=?", key.JID, key.Receiver)
} }
func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal { func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal {
@ -83,9 +113,8 @@ type Portal struct {
db *Database db *Database
log log.Logger log log.Logger
JID types.WhatsAppID Key PortalKey
MXID types.MatrixRoomID MXID types.MatrixRoomID
Owner types.MatrixUserID
Name string Name string
Topic string Topic string
@ -93,7 +122,7 @@ type Portal struct {
} }
func (portal *Portal) Scan(row Scannable) *Portal { func (portal *Portal) Scan(row Scannable) *Portal {
err := row.Scan(&portal.JID, &portal.Owner, &portal.MXID, &portal.Name, &portal.Topic, &portal.Avatar) err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &portal.MXID, &portal.Name, &portal.Topic, &portal.Avatar)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
portal.log.Errorln("Database scan failed:", err) portal.log.Errorln("Database scan failed:", err)
@ -103,15 +132,18 @@ func (portal *Portal) Scan(row Scannable) *Portal {
return portal return portal
} }
func (portal *Portal) Insert() error { func (portal *Portal) mxidPtr() *string {
var mxid *string
if len(portal.MXID) > 0 { if len(portal.MXID) > 0 {
mxid = &portal.MXID return &portal.MXID
} }
return nil
}
func (portal *Portal) Insert() error {
_, err := portal.db.Exec("INSERT INTO portal VALUES (?, ?, ?, ?, ?, ?)", _, err := portal.db.Exec("INSERT INTO portal VALUES (?, ?, ?, ?, ?, ?)",
portal.JID, portal.Owner, mxid, portal.Name, portal.Topic, portal.Avatar) portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to insert %s->%s: %v", portal.JID, portal.Owner, err) portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
} }
return err return err
} }
@ -121,10 +153,10 @@ func (portal *Portal) Update() error {
if len(portal.MXID) > 0 { if len(portal.MXID) > 0 {
mxid = &portal.MXID mxid = &portal.MXID
} }
_, err := portal.db.Exec("UPDATE portal SET mxid=?, name=?, topic=?, avatar=? WHERE jid=? AND owner=?", _, err := portal.db.Exec("UPDATE portal SET mxid=?, name=?, topic=?, avatar=? WHERE jid=? AND receiver=?",
mxid, portal.Name, portal.Topic, portal.Avatar, portal.JID, portal.Owner) mxid, portal.Name, portal.Topic, portal.Avatar, portal.Key.JID, portal.Key.Receiver)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to update %s->%s: %v", portal.JID, portal.Owner, err) portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
} }
return err return err
} }

View file

@ -30,13 +30,9 @@ type PuppetQuery struct {
func (pq *PuppetQuery) CreateTable() error { func (pq *PuppetQuery) CreateTable() error {
_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet ( _, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet (
jid VARCHAR(255), jid VARCHAR(25) PRIMARY KEY,
receiver VARCHAR(255),
displayname VARCHAR(255), displayname VARCHAR(255),
avatar VARCHAR(255), avatar VARCHAR(255)
PRIMARY KEY(jid, receiver)
)`) )`)
return err return err
} }
@ -48,8 +44,8 @@ func (pq *PuppetQuery) New() *Puppet {
} }
} }
func (pq *PuppetQuery) GetAll(receiver types.MatrixUserID) (puppets []*Puppet) { func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
rows, err := pq.db.Query("SELECT * FROM puppet WHERE receiver=%s") rows, err := pq.db.Query("SELECT * FROM puppet")
if err != nil || rows == nil { if err != nil || rows == nil {
return nil return nil
} }
@ -60,8 +56,8 @@ func (pq *PuppetQuery) GetAll(receiver types.MatrixUserID) (puppets []*Puppet) {
return return
} }
func (pq *PuppetQuery) Get(jid types.WhatsAppID, receiver types.MatrixUserID) *Puppet { func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet {
row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=? AND receiver=?", jid, receiver) row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=?", jid)
if row == nil { if row == nil {
return nil return nil
} }
@ -72,15 +68,13 @@ type Puppet struct {
db *Database db *Database
log log.Logger log log.Logger
JID types.WhatsAppID JID types.WhatsAppID
Receiver types.MatrixUserID
Displayname string Displayname string
Avatar string Avatar string
} }
func (puppet *Puppet) Scan(row Scannable) *Puppet { func (puppet *Puppet) Scan(row Scannable) *Puppet {
err := row.Scan(&puppet.JID, &puppet.Receiver, &puppet.Displayname, &puppet.Avatar) err := row.Scan(&puppet.JID, &puppet.Displayname, &puppet.Avatar)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
puppet.log.Errorln("Database scan failed:", err) puppet.log.Errorln("Database scan failed:", err)
@ -91,20 +85,19 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet {
} }
func (puppet *Puppet) Insert() error { func (puppet *Puppet) Insert() error {
_, err := puppet.db.Exec("INSERT INTO puppet VALUES (?, ?, ?, ?)", _, err := puppet.db.Exec("INSERT INTO puppet VALUES (?, ?, ?)",
puppet.JID, puppet.Receiver, puppet.Displayname, puppet.Avatar) puppet.JID, puppet.Displayname, puppet.Avatar)
if err != nil { if err != nil {
puppet.log.Errorfln("Failed to insert %s->%s: %v", puppet.JID, puppet.Receiver, err) puppet.log.Errorfln("Failed to insert %s: %v", puppet.JID, err)
} }
return err return err
} }
func (puppet *Puppet) Update() error { func (puppet *Puppet) Update() error {
_, err := puppet.db.Exec("UPDATE puppet SET displayname=?, avatar=? WHERE jid=? AND receiver=?", _, err := puppet.db.Exec("UPDATE puppet SET displayname=?, avatar=? WHERE jid=?",
puppet.Displayname, puppet.Avatar, puppet.Displayname, puppet.Avatar, puppet.JID)
puppet.JID, puppet.Receiver)
if err != nil { if err != nil {
puppet.log.Errorfln("Failed to update %s->%s: %v", puppet.JID, puppet.Receiver, err) puppet.log.Errorfln("Failed to update %s->%s: %v", puppet.JID, err)
} }
return err return err
} }

View file

@ -32,6 +32,7 @@ type UserQuery struct {
func (uq *UserQuery) CreateTable() error { func (uq *UserQuery) CreateTable() error {
_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS user ( _, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS user (
mxid VARCHAR(255) PRIMARY KEY, mxid VARCHAR(255) PRIMARY KEY,
jid VARCHAR(25) UNIQUE,
management_room VARCHAR(255), management_room VARCHAR(255),
@ -64,7 +65,7 @@ func (uq *UserQuery) GetAll() (users []*User) {
return return
} }
func (uq *UserQuery) Get(userID types.MatrixUserID) *User { func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User {
row := uq.db.QueryRow("SELECT * FROM user WHERE mxid=?", userID) row := uq.db.QueryRow("SELECT * FROM user WHERE mxid=?", userID)
if row == nil { if row == nil {
return nil return nil
@ -72,18 +73,27 @@ func (uq *UserQuery) Get(userID types.MatrixUserID) *User {
return uq.New().Scan(row) return uq.New().Scan(row)
} }
func (uq *UserQuery) GetByJID(userID types.WhatsAppID) *User {
row := uq.db.QueryRow("SELECT * FROM user WHERE jid=?", userID)
if row == nil {
return nil
}
return uq.New().Scan(row)
}
type User struct { type User struct {
db *Database db *Database
log log.Logger log log.Logger
ID types.MatrixUserID MXID types.MatrixUserID
JID types.WhatsAppID
ManagementRoom types.MatrixRoomID ManagementRoom types.MatrixRoomID
Session *whatsapp.Session Session *whatsapp.Session
} }
func (user *User) Scan(row Scannable) *User { func (user *User) Scan(row Scannable) *User {
sess := whatsapp.Session{} sess := whatsapp.Session{}
err := row.Scan(&user.ID, &user.ManagementRoom, &sess.ClientId, &sess.ClientToken, &sess.ServerToken, err := row.Scan(&user.MXID, &user.JID, &user.ManagementRoom, &sess.ClientId, &sess.ClientToken, &sess.ServerToken,
&sess.EncKey, &sess.MacKey, &sess.Wid) &sess.EncKey, &sess.MacKey, &sess.Wid)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
@ -99,23 +109,32 @@ func (user *User) Scan(row Scannable) *User {
return user return user
} }
func (user *User) Insert() error { func (user *User) jidPtr() *string {
var sess whatsapp.Session if len(user.JID) > 0 {
return &user.JID
}
return nil
}
func (user *User) sessionUnptr() (sess whatsapp.Session) {
if user.Session != nil { if user.Session != nil {
sess = *user.Session sess = *user.Session
} }
_, err := user.db.Exec("INSERT INTO user VALUES (?, ?, ?, ?, ?, ?, ?, ?)", user.ID, user.ManagementRoom, return
}
func (user *User) Insert() error {
sess := user.sessionUnptr()
_, err := user.db.Exec("INSERT INTO user VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", user.MXID, user.jidPtr(), user.ManagementRoom,
sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid) sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid)
return err return err
} }
func (user *User) Update() error { func (user *User) Update() error {
var sess whatsapp.Session sess := user.sessionUnptr()
if user.Session != nil { _, err := user.db.Exec("UPDATE user SET jid=?, management_room=?, client_id=?, client_token=?, server_token=?, enc_key=?, mac_key=?, wid=? WHERE mxid=?",
sess = *user.Session user.jidPtr(), user.ManagementRoom,
} sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid,
_, err := user.db.Exec("UPDATE user SET management_room=?, client_id=?, client_token=?, server_token=?, enc_key=?, mac_key=?, wid=? WHERE mxid=?", user.MXID)
user.ManagementRoom,
sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.Wid, user.ID)
return err return err
} }

View file

@ -21,7 +21,6 @@ appservice:
type: sqlite3 type: sqlite3
# The database URI. Usually file name. https://github.com/mattn/go-sqlite3#connection-string # The database URI. Usually file name. https://github.com/mattn/go-sqlite3#connection-string
uri: mautrix-whatsapp.db uri: mautrix-whatsapp.db
# Path to the Matrix room state store. # Path to the Matrix room state store.
state_store_path: ./mx-state.json state_store_path: ./mx-state.json
@ -43,15 +42,15 @@ appservice:
# Bridge config. Currently unused. # Bridge config. Currently unused.
bridge: bridge:
# Localpart template of MXIDs for WhatsApp users. # Localpart template of MXIDs for WhatsApp users.
# {{.Receiver}} is replaced with the WhatsApp user ID of the Matrix user receiving messages. # {{.}} is replaced with the phone number of the WhatsApp user.
# {{.UserID}} is replaced with the user ID of the WhatsApp user. username_template: whatsapp_{{.}}
username_template: "whatsapp_{{.Receiver}}_{{.UserID}}"
# Displayname template for WhatsApp users. # Displayname template for WhatsApp users.
# {{.Name}} - display name # {{.Notify}} - nickname set by the WhatsApp user
# {{.Short}} - short display name (usually first name)
# {{.Notify}} - nickname (maybe set by the target WhatsApp user)
# {{.Jid}} - phone number (international format) # {{.Jid}} - phone number (international format)
displayname_template: "{{if .Name}}{{.Name}}{{else if .Notify}}{{.Notify}}{{else if .Short}}{{.Short}}{{else}}{{.Jid}}{{end}}" # The following variables are also available, but will cause problems on multi-user instances:
# {{.Name}} - display name from contact list
# {{.Short}} - short display name from contact list
displayname_template: "{{if .Notify}}{{.Notify}}{{else}}{{.Jid}}{{end}} (WA)"
# The prefix for commands. Only required in non-management rooms. # The prefix for commands. Only required in non-management rooms.
command_prefix: "!wa" command_prefix: "!wa"
@ -72,8 +71,8 @@ bridge:
logging: logging:
# The directory for log files. Will be created if not found. # The directory for log files. Will be created if not found.
directory: ./logs directory: ./logs
# Available variables: .date for the file date and .index for different log files on the same day. # Available variables: .Date for the file date and .Index for different log files on the same day.
file_name_format: "{{.date}}-{{.index}.log" file_name_format: "{{.Date}}-{{.Index}}.log"
# Date format for file names in the Go time format: https://golang.org/pkg/time/#pkg-constants # Date format for file names in the Go time format: https://golang.org/pkg/time/#pkg-constants
file_date_format: 2006-01-02 file_date_format: 2006-01-02
# Log file permissions. # Log file permissions.

View file

@ -18,58 +18,71 @@ package main
import ( import (
"fmt" "fmt"
"html"
"regexp" "regexp"
"strings" "strings"
"maunium.net/go/gomatrix"
"maunium.net/go/gomatrix/format" "maunium.net/go/gomatrix/format"
"maunium.net/go/mautrix-whatsapp/whatsapp-ext" "maunium.net/go/mautrix-whatsapp/whatsapp-ext"
) )
func (user *User) newHTMLParser() *format.HTMLParser {
return &format.HTMLParser{
TabsToSpaces: 4,
Newline: "\n",
PillConverter: func(mxid, eventID string) string {
if mxid[0] == '@' {
puppet := user.GetPuppetByMXID(mxid)
fmt.Println(mxid, puppet)
if puppet != nil {
return "@" + puppet.PhoneNumber()
}
}
return mxid
},
BoldConverter: func(text string) string {
return fmt.Sprintf("*%s*", text)
},
ItalicConverter: func(text string) string {
return fmt.Sprintf("_%s_", text)
},
StrikethroughConverter: func(text string) string {
return fmt.Sprintf("~%s~", text)
},
MonospaceConverter: func(text string) string {
return fmt.Sprintf("```%s```", text)
},
MonospaceBlockConverter: func(text string) string {
return fmt.Sprintf("```%s```", text)
},
}
}
var italicRegex = regexp.MustCompile("([\\s>~*]|^)_(.+?)_([^a-zA-Z\\d]|$)") var italicRegex = regexp.MustCompile("([\\s>~*]|^)_(.+?)_([^a-zA-Z\\d]|$)")
var boldRegex = regexp.MustCompile("([\\s>_~]|^)\\*(.+?)\\*([^a-zA-Z\\d]|$)") var boldRegex = regexp.MustCompile("([\\s>_~]|^)\\*(.+?)\\*([^a-zA-Z\\d]|$)")
var strikethroughRegex = regexp.MustCompile("([\\s>_*]|^)~(.+?)~([^a-zA-Z\\d]|$)") var strikethroughRegex = regexp.MustCompile("([\\s>_*]|^)~(.+?)~([^a-zA-Z\\d]|$)")
var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```") var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```")
var mentionRegex = regexp.MustCompile("@[0-9]+") var mentionRegex = regexp.MustCompile("@[0-9]+")
func (user *User) newWhatsAppFormatMaps() (map[*regexp.Regexp]string, map[*regexp.Regexp]func(string) string, map[*regexp.Regexp]func(string) string) { type Formatter struct {
return map[*regexp.Regexp]string{ bridge *Bridge
italicRegex: "$1<em>$2</em>$3",
boldRegex: "$1<strong>$2</strong>$3", matrixHTMLParser *format.HTMLParser
strikethroughRegex: "$1<del>$2</del>$3",
}, map[*regexp.Regexp]func(string) string{ waReplString map[*regexp.Regexp]string
waReplFunc map[*regexp.Regexp]func(string) string
waReplFuncText map[*regexp.Regexp]func(string) string
}
func NewFormatter(bridge *Bridge) *Formatter {
formatter := &Formatter{
bridge: bridge,
matrixHTMLParser: &format.HTMLParser{
TabsToSpaces: 4,
Newline: "\n",
PillConverter: func(mxid, eventID string) string {
if mxid[0] == '@' {
puppet := bridge.GetPuppetByMXID(mxid)
fmt.Println(mxid, puppet)
if puppet != nil {
return "@" + puppet.PhoneNumber()
}
}
return mxid
},
BoldConverter: func(text string) string {
return fmt.Sprintf("*%s*", text)
},
ItalicConverter: func(text string) string {
return fmt.Sprintf("_%s_", text)
},
StrikethroughConverter: func(text string) string {
return fmt.Sprintf("~%s~", text)
},
MonospaceConverter: func(text string) string {
return fmt.Sprintf("```%s```", text)
},
MonospaceBlockConverter: func(text string) string {
return fmt.Sprintf("```%s```", text)
},
},
waReplString: map[*regexp.Regexp]string{
italicRegex: "$1<em>$2</em>$3",
boldRegex: "$1<strong>$2</strong>$3",
strikethroughRegex: "$1<del>$2</del>$3",
},
}
formatter.waReplFunc = map[*regexp.Regexp]func(string) string{
codeBlockRegex: func(str string) string { codeBlockRegex: func(str string) string {
str = str[3 : len(str)-3] str = str[3 : len(str)-3]
if strings.ContainsRune(str, '\n') { if strings.ContainsRune(str, '\n') {
@ -78,18 +91,47 @@ func (user *User) newWhatsAppFormatMaps() (map[*regexp.Regexp]string, map[*regex
return fmt.Sprintf("<code>%s</code>", str) return fmt.Sprintf("<code>%s</code>", str)
}, },
mentionRegex: func(str string) string { mentionRegex: func(str string) string {
jid := str[1:] + whatsappExt.NewUserSuffix mxid, displayname := formatter.getMatrixInfoByJID(str[1:] + whatsappExt.NewUserSuffix)
puppet := user.GetPuppetByJID(jid) return fmt.Sprintf(`<a href="https://matrix.to/#/%s">%s</a>`, mxid, displayname)
mxid := puppet.MXID
if jid == user.JID() {
mxid = user.ID
}
return fmt.Sprintf(`<a href="https://matrix.to/#/%s">%s</a>`, mxid, puppet.Displayname)
},
}, map[*regexp.Regexp]func(string)string {
mentionRegex: func(str string) string {
puppet := user.GetPuppetByJID(str[1:] + whatsappExt.NewUserSuffix)
return puppet.Displayname
}, },
} }
formatter.waReplFuncText = map[*regexp.Regexp]func(string) string{
mentionRegex: func(str string) string {
_, displayname := formatter.getMatrixInfoByJID(str[1:] + whatsappExt.NewUserSuffix)
return displayname
},
}
return formatter
}
func (formatter *Formatter) getMatrixInfoByJID(jid string) (mxid, displayname string) {
if user := formatter.bridge.GetUserByJID(jid); user != nil {
mxid = user.MXID
displayname = user.MXID
} else if puppet := formatter.bridge.GetPuppetByJID(jid); puppet != nil {
mxid = puppet.MXID
displayname = puppet.Displayname
}
return
}
func (formatter *Formatter) ParseWhatsApp(content *gomatrix.Content) {
output := html.EscapeString(content.Body)
for regex, replacement := range formatter.waReplString {
output = regex.ReplaceAllString(output, replacement)
}
for regex, replacer := range formatter.waReplFunc {
output = regex.ReplaceAllStringFunc(output, replacer)
}
if output != content.Body {
content.FormattedBody = output
content.Format = gomatrix.FormatHTML
for regex, replacer := range formatter.waReplFuncText {
content.Body = regex.ReplaceAllStringFunc(content.Body, replacer)
}
}
}
func (formatter *Formatter) ParseMatrix(html string) string {
return formatter.matrixHTMLParser.Parse(html)
} }

77
main.go
View file

@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"os" "os"
"os/signal" "os/signal"
"sync"
"syscall" "syscall"
flag "maunium.net/go/mauflag" flag "maunium.net/go/mauflag"
@ -31,6 +32,7 @@ import (
) )
var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String() var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String()
var baseConfigPath = flag.MakeFull("b", "base-config", "The path to the example config file.", "example-config.yaml").String()
var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String() var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String()
var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool() var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool()
var wantHelp, _ = flag.MakeHelpFlag() var wantHelp, _ = flag.MakeHelpFlag()
@ -58,29 +60,47 @@ func (bridge *Bridge) GenerateRegistration() {
} }
type Bridge struct { type Bridge struct {
AppService *appservice.AppService AS *appservice.AppService
EventProcessor *appservice.EventProcessor EventProcessor *appservice.EventProcessor
MatrixHandler *MatrixHandler MatrixHandler *MatrixHandler
Config *config.Config Config *config.Config
DB *database.Database DB *database.Database
Log log.Logger Log log.Logger
StateStore *AutosavingStateStore
Bot *appservice.IntentAPI
Formatter *Formatter
StateStore *AutosavingStateStore usersByMXID map[types.MatrixUserID]*User
usersByJID map[types.WhatsAppID]*User
users map[types.MatrixUserID]*User usersLock sync.Mutex
managementRooms map[types.MatrixRoomID]*User managementRooms map[types.MatrixRoomID]*User
managementRoomsLock sync.Mutex
portalsByMXID map[types.MatrixRoomID]*Portal
portalsByJID map[database.PortalKey]*Portal
portalsLock sync.Mutex
puppets map[types.WhatsAppID]*Puppet
puppetsLock sync.Mutex
} }
func NewBridge() *Bridge { func NewBridge() *Bridge {
bridge := &Bridge{ bridge := &Bridge{
users: make(map[types.MatrixUserID]*User), usersByMXID: make(map[types.MatrixUserID]*User),
usersByJID: make(map[types.WhatsAppID]*User),
managementRooms: make(map[types.MatrixRoomID]*User), managementRooms: make(map[types.MatrixRoomID]*User),
portalsByMXID: make(map[types.MatrixRoomID]*Portal),
portalsByJID: make(map[database.PortalKey]*Portal),
puppets: make(map[types.WhatsAppID]*Puppet),
} }
var err error err := config.Update(*configPath, *baseConfigPath)
if err != nil {
fmt.Fprintln(os.Stderr, "Failed to update config:", err)
os.Exit(10)
}
bridge.Config, err = config.Load(*configPath) bridge.Config, err = config.Load(*configPath)
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, "Failed to load config:", err) fmt.Fprintln(os.Stderr, "Failed to load config:", err)
os.Exit(10) os.Exit(11)
} }
return bridge return bridge
} }
@ -88,46 +108,55 @@ func NewBridge() *Bridge {
func (bridge *Bridge) Init() { func (bridge *Bridge) Init() {
var err error var err error
bridge.AppService, err = bridge.Config.MakeAppService() bridge.AS, err = bridge.Config.MakeAppService()
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, "Failed to initialize AppService:", err) fmt.Fprintln(os.Stderr, "Failed to initialize AppService:", err)
os.Exit(11) os.Exit(12)
} }
bridge.AppService.Init() bridge.AS.Init()
bridge.Log = bridge.AppService.Log bridge.Bot = bridge.AS.BotIntent()
bridge.Log = log.Create()
bridge.Config.Logging.Configure(bridge.Log)
log.DefaultLogger = bridge.Log.(*log.BasicLogger) log.DefaultLogger = bridge.Log.(*log.BasicLogger)
bridge.AppService.Log = log.Sub("Matrix") err = log.OpenFile()
if err != nil {
fmt.Fprintln(os.Stderr, "Failed to open log file:", err)
os.Exit(13)
}
bridge.AS.Log = log.Sub("Matrix")
bridge.Log.Debugln("Initializing state store") bridge.Log.Debugln("Initializing state store")
bridge.StateStore = NewAutosavingStateStore(bridge.Config.AppService.StateStore) bridge.StateStore = NewAutosavingStateStore(bridge.Config.AppService.StateStore)
err = bridge.StateStore.Load() err = bridge.StateStore.Load()
if err != nil { if err != nil {
bridge.Log.Fatalln("Failed to load state store:", err) bridge.Log.Fatalln("Failed to load state store:", err)
os.Exit(12) os.Exit(14)
} }
bridge.AppService.StateStore = bridge.StateStore bridge.AS.StateStore = bridge.StateStore
bridge.Log.Debugln("Initializing database") bridge.Log.Debugln("Initializing database")
bridge.DB, err = database.New(bridge.Config.AppService.Database.URI) bridge.DB, err = database.New(bridge.Config.AppService.Database.URI)
if err != nil { if err != nil {
bridge.Log.Fatalln("Failed to initialize database:", err) bridge.Log.Fatalln("Failed to initialize database:", err)
os.Exit(13) os.Exit(15)
} }
bridge.Log.Debugln("Initializing Matrix event processor") bridge.Log.Debugln("Initializing Matrix event processor")
bridge.EventProcessor = appservice.NewEventProcessor(bridge.AppService) bridge.EventProcessor = appservice.NewEventProcessor(bridge.AS)
bridge.Log.Debugln("Initializing Matrix event handler") bridge.Log.Debugln("Initializing Matrix event handler")
bridge.MatrixHandler = NewMatrixHandler(bridge) bridge.MatrixHandler = NewMatrixHandler(bridge)
bridge.Formatter = NewFormatter(bridge)
} }
func (bridge *Bridge) Start() { func (bridge *Bridge) Start() {
err := bridge.DB.CreateTables() err := bridge.DB.CreateTables()
if err != nil { if err != nil {
bridge.Log.Fatalln("Failed to create database tables:", err) bridge.Log.Fatalln("Failed to create database tables:", err)
os.Exit(14) os.Exit(16)
} }
bridge.Log.Debugln("Starting application service HTTP server") bridge.Log.Debugln("Starting application service HTTP server")
go bridge.AppService.Start() go bridge.AS.Start()
bridge.Log.Debugln("Starting event processor") bridge.Log.Debugln("Starting event processor")
go bridge.EventProcessor.Start() go bridge.EventProcessor.Start()
go bridge.UpdateBotProfile() go bridge.UpdateBotProfile()
@ -140,18 +169,18 @@ func (bridge *Bridge) UpdateBotProfile() {
var err error var err error
if botConfig.Avatar == "remove" { if botConfig.Avatar == "remove" {
err = bridge.AppService.BotIntent().SetAvatarURL("") err = bridge.AS.BotIntent().SetAvatarURL("")
} else if len(botConfig.Avatar) > 0 { } else if len(botConfig.Avatar) > 0 {
err = bridge.AppService.BotIntent().SetAvatarURL(botConfig.Avatar) err = bridge.AS.BotIntent().SetAvatarURL(botConfig.Avatar)
} }
if err != nil { if err != nil {
bridge.Log.Warnln("Failed to update bot avatar:", err) bridge.Log.Warnln("Failed to update bot avatar:", err)
} }
if botConfig.Displayname == "remove" { if botConfig.Displayname == "remove" {
err = bridge.AppService.BotIntent().SetDisplayName("") err = bridge.AS.BotIntent().SetDisplayName("")
} else if len(botConfig.Avatar) > 0 { } else if len(botConfig.Avatar) > 0 {
err = bridge.AppService.BotIntent().SetDisplayName(botConfig.Displayname) err = bridge.AS.BotIntent().SetDisplayName(botConfig.Displayname)
} }
if err != nil { if err != nil {
bridge.Log.Warnln("Failed to update bot displayname:", err) bridge.Log.Warnln("Failed to update bot displayname:", err)
@ -165,7 +194,7 @@ func (bridge *Bridge) StartUsers() {
} }
func (bridge *Bridge) Stop() { func (bridge *Bridge) Stop() {
bridge.AppService.Stop() bridge.AS.Stop()
bridge.EventProcessor.Stop() bridge.EventProcessor.Stop()
err := bridge.StateStore.Save() err := bridge.StateStore.Save()
if err != nil { if err != nil {

View file

@ -35,7 +35,7 @@ type MatrixHandler struct {
func NewMatrixHandler(bridge *Bridge) *MatrixHandler { func NewMatrixHandler(bridge *Bridge) *MatrixHandler {
handler := &MatrixHandler{ handler := &MatrixHandler{
bridge: bridge, bridge: bridge,
as: bridge.AppService, as: bridge.AS,
log: bridge.Log.Sub("Matrix"), log: bridge.Log.Sub("Matrix"),
cmd: NewCommandHandler(bridge), cmd: NewCommandHandler(bridge),
} }
@ -50,7 +50,7 @@ func NewMatrixHandler(bridge *Bridge) *MatrixHandler {
func (mx *MatrixHandler) HandleBotInvite(evt *gomatrix.Event) { func (mx *MatrixHandler) HandleBotInvite(evt *gomatrix.Event) {
intent := mx.as.BotIntent() intent := mx.as.BotIntent()
user := mx.bridge.GetUser(evt.Sender) user := mx.bridge.GetUserByMXID(evt.Sender)
if user == nil { if user == nil {
return return
} }
@ -85,7 +85,7 @@ func (mx *MatrixHandler) HandleBotInvite(evt *gomatrix.Event) {
for mxid, _ := range members.Joined { for mxid, _ := range members.Joined {
if mxid == intent.UserID || mxid == evt.Sender { if mxid == intent.UserID || mxid == evt.Sender {
continue continue
} else if _, _, ok := mx.bridge.ParsePuppetMXID(types.MatrixUserID(mxid)); ok { } else if _, ok := mx.bridge.ParsePuppetMXID(types.MatrixUserID(mxid)); ok {
hasPuppets = true hasPuppets = true
continue continue
} }
@ -96,7 +96,7 @@ func (mx *MatrixHandler) HandleBotInvite(evt *gomatrix.Event) {
} }
if !hasPuppets { if !hasPuppets {
user := mx.bridge.GetUser(types.MatrixUserID(evt.Sender)) user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
user.SetManagementRoom(types.MatrixRoomID(resp.RoomID)) user.SetManagementRoom(types.MatrixRoomID(resp.RoomID))
intent.SendNotice(string(user.ManagementRoom), "This room has been registered as your bridge management/status room.") intent.SendNotice(string(user.ManagementRoom), "This room has been registered as your bridge management/status room.")
mx.log.Debugln(resp.RoomID, "registered as a management room with", evt.Sender) mx.log.Debugln(resp.RoomID, "registered as a management room with", evt.Sender)
@ -110,12 +110,12 @@ func (mx *MatrixHandler) HandleMembership(evt *gomatrix.Event) {
} }
func (mx *MatrixHandler) HandleRoomMetadata(evt *gomatrix.Event) { func (mx *MatrixHandler) HandleRoomMetadata(evt *gomatrix.Event) {
user := mx.bridge.GetUser(types.MatrixUserID(evt.Sender)) user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
if user == nil || !user.Whitelisted { if user == nil || !user.Whitelisted || !user.IsLoggedIn() {
return return
} }
portal := user.GetPortalByMXID(evt.RoomID) portal := mx.bridge.GetPortalByMXID(evt.RoomID)
if portal == nil || portal.IsPrivateChat() { if portal == nil || portal.IsPrivateChat() {
return return
} }
@ -124,7 +124,7 @@ func (mx *MatrixHandler) HandleRoomMetadata(evt *gomatrix.Event) {
var err error var err error
switch evt.Type { switch evt.Type {
case gomatrix.StateRoomName: case gomatrix.StateRoomName:
resp, err = user.Conn.UpdateGroupSubject(evt.Content.Name, portal.JID) resp, err = user.Conn.UpdateGroupSubject(evt.Content.Name, portal.Key.JID)
case gomatrix.StateRoomAvatar: case gomatrix.StateRoomAvatar:
return return
case gomatrix.StateTopic: case gomatrix.StateTopic:
@ -140,7 +140,7 @@ func (mx *MatrixHandler) HandleRoomMetadata(evt *gomatrix.Event) {
func (mx *MatrixHandler) HandleMessage(evt *gomatrix.Event) { func (mx *MatrixHandler) HandleMessage(evt *gomatrix.Event) {
roomID := types.MatrixRoomID(evt.RoomID) roomID := types.MatrixRoomID(evt.RoomID)
user := mx.bridge.GetUser(types.MatrixUserID(evt.Sender)) user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
if !user.Whitelisted { if !user.Whitelisted {
return return
@ -158,8 +158,12 @@ func (mx *MatrixHandler) HandleMessage(evt *gomatrix.Event) {
} }
} }
portal := user.GetPortalByMXID(roomID) if !user.IsLoggedIn() {
return
}
portal := mx.bridge.GetPortalByMXID(roomID)
if portal != nil { if portal != nil {
portal.HandleMatrixMessage(evt) portal.HandleMatrixMessage(user, evt)
} }
} }

205
portal.go
View file

@ -20,7 +20,6 @@ import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"html"
"image" "image"
"image/gif" "image/gif"
"image/jpeg" "image/jpeg"
@ -41,57 +40,56 @@ import (
"maunium.net/go/mautrix-whatsapp/whatsapp-ext" "maunium.net/go/mautrix-whatsapp/whatsapp-ext"
) )
func (user *User) GetPortalByMXID(mxid types.MatrixRoomID) *Portal { func (bridge *Bridge) GetPortalByMXID(mxid types.MatrixRoomID) *Portal {
user.portalsLock.Lock() bridge.portalsLock.Lock()
defer user.portalsLock.Unlock() defer bridge.portalsLock.Unlock()
portal, ok := user.portalsByMXID[mxid] portal, ok := bridge.portalsByMXID[mxid]
if !ok { if !ok {
dbPortal := user.bridge.DB.Portal.GetByMXID(mxid) dbPortal := bridge.DB.Portal.GetByMXID(mxid)
if dbPortal == nil || dbPortal.Owner != user.ID { if dbPortal == nil {
return nil return nil
} }
portal = user.NewPortal(dbPortal) portal = bridge.NewPortal(dbPortal)
user.portalsByJID[portal.JID] = portal bridge.portalsByJID[portal.Key] = portal
if len(portal.MXID) > 0 { if len(portal.MXID) > 0 {
user.portalsByMXID[portal.MXID] = portal bridge.portalsByMXID[portal.MXID] = portal
} }
} }
return portal return portal
} }
func (user *User) GetPortalByJID(jid types.WhatsAppID) *Portal { func (bridge *Bridge) GetPortalByJID(key database.PortalKey) *Portal {
user.portalsLock.Lock() bridge.portalsLock.Lock()
defer user.portalsLock.Unlock() defer bridge.portalsLock.Unlock()
portal, ok := user.portalsByJID[jid] portal, ok := bridge.portalsByJID[key]
if !ok { if !ok {
dbPortal := user.bridge.DB.Portal.GetByJID(user.ID, jid) dbPortal := bridge.DB.Portal.GetByJID(key)
if dbPortal == nil { if dbPortal == nil {
dbPortal = user.bridge.DB.Portal.New() dbPortal = bridge.DB.Portal.New()
dbPortal.JID = jid dbPortal.Key = key
dbPortal.Owner = user.ID
dbPortal.Insert() dbPortal.Insert()
} }
portal = user.NewPortal(dbPortal) portal = bridge.NewPortal(dbPortal)
user.portalsByJID[portal.JID] = portal bridge.portalsByJID[portal.Key] = portal
if len(portal.MXID) > 0 { if len(portal.MXID) > 0 {
user.portalsByMXID[portal.MXID] = portal bridge.portalsByMXID[portal.MXID] = portal
} }
} }
return portal return portal
} }
func (user *User) GetAllPortals() []*Portal { func (bridge *Bridge) GetAllPortals() []*Portal {
user.portalsLock.Lock() bridge.portalsLock.Lock()
defer user.portalsLock.Unlock() defer bridge.portalsLock.Unlock()
dbPortals := user.bridge.DB.Portal.GetAll(user.ID) dbPortals := bridge.DB.Portal.GetAll()
output := make([]*Portal, len(dbPortals)) output := make([]*Portal, len(dbPortals))
for index, dbPortal := range dbPortals { for index, dbPortal := range dbPortals {
portal, ok := user.portalsByJID[dbPortal.JID] portal, ok := bridge.portalsByJID[dbPortal.Key]
if !ok { if !ok {
portal = user.NewPortal(dbPortal) portal = bridge.NewPortal(dbPortal)
user.portalsByJID[dbPortal.JID] = portal bridge.portalsByJID[portal.Key] = portal
if len(dbPortal.MXID) > 0 { if len(dbPortal.MXID) > 0 {
user.portalsByMXID[dbPortal.MXID] = portal bridge.portalsByMXID[dbPortal.MXID] = portal
} }
} }
output[index] = portal output[index] = portal
@ -99,19 +97,17 @@ func (user *User) GetAllPortals() []*Portal {
return output return output
} }
func (user *User) NewPortal(dbPortal *database.Portal) *Portal { func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
return &Portal{ return &Portal{
Portal: dbPortal, Portal: dbPortal,
user: user, bridge: bridge,
bridge: user.bridge, log: bridge.Log.Sub(fmt.Sprintf("Portal/%s", dbPortal.Key)),
log: user.log.Sub(fmt.Sprintf("Portal/%s", dbPortal.JID)),
} }
} }
type Portal struct { type Portal struct {
*database.Portal *database.Portal
user *User
bridge *Bridge bridge *Bridge
log log.Logger log log.Logger
@ -126,9 +122,16 @@ func (portal *Portal) SyncParticipants(metadata *whatsappExt.GroupInfo) {
changed = true changed = true
} }
for _, participant := range metadata.Participants { for _, participant := range metadata.Participants {
puppet := portal.user.GetPuppetByJID(participant.JID) puppet := portal.bridge.GetPuppetByJID(participant.JID)
puppet.Intent().EnsureJoined(portal.MXID) puppet.Intent().EnsureJoined(portal.MXID)
user := portal.bridge.GetUserByJID(participant.JID)
if !portal.bridge.AS.StateStore.IsInvited(portal.MXID, user.MXID) {
portal.MainIntent().InviteUser(portal.MXID, &gomatrix.ReqInviteUser{
UserID: user.MXID,
})
}
expectedLevel := 0 expectedLevel := 0
if participant.IsSuperAdmin { if participant.IsSuperAdmin {
expectedLevel = 95 expectedLevel = 95
@ -136,9 +139,8 @@ func (portal *Portal) SyncParticipants(metadata *whatsappExt.GroupInfo) {
expectedLevel = 50 expectedLevel = 50
} }
changed = levels.EnsureUserLevel(puppet.MXID, expectedLevel) || changed changed = levels.EnsureUserLevel(puppet.MXID, expectedLevel) || changed
if user != nil {
if participant.JID == portal.user.JID() { changed = levels.EnsureUserLevel(user.MXID, expectedLevel) || changed
changed = levels.EnsureUserLevel(portal.user.ID, expectedLevel) || changed
} }
} }
if changed { if changed {
@ -146,10 +148,10 @@ func (portal *Portal) SyncParticipants(metadata *whatsappExt.GroupInfo) {
} }
} }
func (portal *Portal) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool { func (portal *Portal) UpdateAvatar(user *User, avatar *whatsappExt.ProfilePicInfo) bool {
if avatar == nil { if avatar == nil {
var err error var err error
avatar, err = portal.user.Conn.GetProfilePicThumb(portal.JID) avatar, err = user.Conn.GetProfilePicThumb(portal.Key.JID)
if err != nil { if err != nil {
portal.log.Errorln(err) portal.log.Errorln(err)
return false return false
@ -184,7 +186,7 @@ func (portal *Portal) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool {
func (portal *Portal) UpdateName(name string, setBy types.WhatsAppID) bool { func (portal *Portal) UpdateName(name string, setBy types.WhatsAppID) bool {
if portal.Name != name { if portal.Name != name {
intent := portal.user.GetPuppetByJID(setBy).Intent() intent := portal.bridge.GetPuppetByJID(setBy).Intent()
_, err := intent.SetRoomName(portal.MXID, name) _, err := intent.SetRoomName(portal.MXID, name)
if err == nil { if err == nil {
portal.Name = name portal.Name = name
@ -197,7 +199,7 @@ func (portal *Portal) UpdateName(name string, setBy types.WhatsAppID) bool {
func (portal *Portal) UpdateTopic(topic string, setBy types.WhatsAppID) bool { func (portal *Portal) UpdateTopic(topic string, setBy types.WhatsAppID) bool {
if portal.Topic != topic { if portal.Topic != topic {
intent := portal.user.GetPuppetByJID(setBy).Intent() intent := portal.bridge.GetPuppetByJID(setBy).Intent()
_, err := intent.SetRoomTopic(portal.MXID, topic) _, err := intent.SetRoomTopic(portal.MXID, topic)
if err == nil { if err == nil {
portal.Topic = topic portal.Topic = topic
@ -208,8 +210,8 @@ func (portal *Portal) UpdateTopic(topic string, setBy types.WhatsAppID) bool {
return false return false
} }
func (portal *Portal) UpdateMetadata() bool { func (portal *Portal) UpdateMetadata(user *User) bool {
metadata, err := portal.user.Conn.GetGroupMetaData(portal.JID) metadata, err := user.Conn.GetGroupMetaData(portal.Key.JID)
if err != nil { if err != nil {
portal.log.Errorln(err) portal.log.Errorln(err)
return false return false
@ -221,25 +223,23 @@ func (portal *Portal) UpdateMetadata() bool {
return update return update
} }
func (portal *Portal) Sync(contact whatsapp.Contact) { func (portal *Portal) Sync(user *User, contact whatsapp.Contact) {
if portal.IsPrivateChat() {
return
}
if len(portal.MXID) == 0 { if len(portal.MXID) == 0 {
if !portal.IsPrivateChat() { portal.Name = contact.Name
portal.Name = contact.Name err := portal.CreateMatrixRoom([]string{user.MXID})
}
err := portal.CreateMatrixRoom()
if err != nil { if err != nil {
portal.log.Errorln("Failed to create portal room:", err) portal.log.Errorln("Failed to create portal room:", err)
return return
} }
} }
if portal.IsPrivateChat() {
return
}
update := false update := false
update = portal.UpdateMetadata() || update update = portal.UpdateMetadata(user) || update
update = portal.UpdateAvatar(nil) || update update = portal.UpdateAvatar(user, nil) || update
if update { if update {
portal.Update() portal.Update()
} }
@ -277,11 +277,12 @@ func (portal *Portal) ChangeAdminStatus(jids []string, setAdmin bool) {
} }
changed := false changed := false
for _, jid := range jids { for _, jid := range jids {
puppet := portal.user.GetPuppetByJID(jid) puppet := portal.bridge.GetPuppetByJID(jid)
changed = levels.EnsureUserLevel(puppet.MXID, newLevel) || changed changed = levels.EnsureUserLevel(puppet.MXID, newLevel) || changed
if jid == portal.user.JID() { user := portal.bridge.GetUserByJID(jid)
changed = levels.EnsureUserLevel(portal.user.ID, newLevel) || changed if user != nil {
changed = levels.EnsureUserLevel(user.MXID, newLevel) || changed
} }
} }
if changed { if changed {
@ -312,15 +313,15 @@ func (portal *Portal) RestrictMetadataChanges(restrict bool) {
newLevel = 50 newLevel = 50
} }
changed := false changed := false
changed = levels.EnsureEventLevel(gomatrix.StateRoomName, true, newLevel) || changed changed = levels.EnsureEventLevel(gomatrix.StateRoomName, newLevel) || changed
changed = levels.EnsureEventLevel(gomatrix.StateRoomAvatar, true, newLevel) || changed changed = levels.EnsureEventLevel(gomatrix.StateRoomAvatar, newLevel) || changed
changed = levels.EnsureEventLevel(gomatrix.StateTopic, true, newLevel) || changed changed = levels.EnsureEventLevel(gomatrix.StateTopic, newLevel) || changed
if changed { if changed {
portal.MainIntent().SetPowerLevels(portal.MXID, levels) portal.MainIntent().SetPowerLevels(portal.MXID, levels)
} }
} }
func (portal *Portal) CreateMatrixRoom() error { func (portal *Portal) CreateMatrixRoom(invite []string) error {
portal.roomCreateLock.Lock() portal.roomCreateLock.Lock()
defer portal.roomCreateLock.Unlock() defer portal.roomCreateLock.Unlock()
if len(portal.MXID) > 0 { if len(portal.MXID) > 0 {
@ -330,7 +331,6 @@ func (portal *Portal) CreateMatrixRoom() error {
name := portal.Name name := portal.Name
topic := portal.Topic topic := portal.Topic
isPrivateChat := false isPrivateChat := false
invite := []string{portal.user.ID}
if portal.IsPrivateChat() { if portal.IsPrivateChat() {
name = "" name = ""
topic = "WhatsApp private chat" topic = "WhatsApp private chat"
@ -360,18 +360,18 @@ func (portal *Portal) CreateMatrixRoom() error {
} }
func (portal *Portal) IsPrivateChat() bool { func (portal *Portal) IsPrivateChat() bool {
return strings.HasSuffix(portal.JID, whatsappExt.NewUserSuffix) return strings.HasSuffix(portal.Key.JID, whatsappExt.NewUserSuffix)
} }
func (portal *Portal) MainIntent() *appservice.IntentAPI { func (portal *Portal) MainIntent() *appservice.IntentAPI {
if portal.IsPrivateChat() { if portal.IsPrivateChat() {
return portal.user.GetPuppetByJID(portal.JID).Intent() return portal.bridge.GetPuppetByJID(portal.Key.JID).Intent()
} }
return portal.bridge.AppService.BotIntent() return portal.bridge.AS.BotIntent()
} }
func (portal *Portal) IsDuplicate(id types.WhatsAppMessageID) bool { func (portal *Portal) IsDuplicate(id types.WhatsAppMessageID) bool {
msg := portal.bridge.DB.Message.GetByJID(portal.Owner, id) msg := portal.bridge.DB.Message.GetByJID(portal.Key, id)
if msg != nil { if msg != nil {
return true return true
} }
@ -380,7 +380,7 @@ func (portal *Portal) IsDuplicate(id types.WhatsAppMessageID) bool {
func (portal *Portal) MarkHandled(jid types.WhatsAppMessageID, mxid types.MatrixEventID) { func (portal *Portal) MarkHandled(jid types.WhatsAppMessageID, mxid types.MatrixEventID) {
msg := portal.bridge.DB.Message.New() msg := portal.bridge.DB.Message.New()
msg.Owner = portal.Owner msg.Chat = portal.Key
msg.JID = jid msg.JID = jid
msg.MXID = mxid msg.MXID = mxid
msg.Insert() msg.Insert()
@ -392,7 +392,7 @@ func (portal *Portal) GetMessageIntent(info whatsapp.MessageInfo) *appservice.In
// TODO handle own messages in private chats properly // TODO handle own messages in private chats properly
return nil return nil
} }
return portal.user.GetPuppetByJID(portal.user.JID()).Intent() return portal.bridge.GetPuppetByJID(portal.Key.Receiver).Intent()
} else if portal.IsPrivateChat() { } else if portal.IsPrivateChat() {
return portal.MainIntent() return portal.MainIntent()
} else if len(info.SenderJid) == 0 { } else if len(info.SenderJid) == 0 {
@ -402,14 +402,14 @@ func (portal *Portal) GetMessageIntent(info whatsapp.MessageInfo) *appservice.In
return nil return nil
} }
} }
return portal.user.GetPuppetByJID(info.SenderJid).Intent() return portal.bridge.GetPuppetByJID(info.SenderJid).Intent()
} }
func (portal *Portal) SetReply(content *gomatrix.Content, info whatsapp.MessageInfo) { func (portal *Portal) SetReply(content *gomatrix.Content, info whatsapp.MessageInfo) {
if len(info.QuotedMessageID) == 0 { if len(info.QuotedMessageID) == 0 {
return return
} }
message := portal.bridge.DB.Message.GetByJID(portal.Owner, info.QuotedMessageID) message := portal.bridge.DB.Message.GetByJID(portal.Key, info.QuotedMessageID)
if message != nil { if message != nil {
event, err := portal.MainIntent().GetEvent(portal.MXID, message.MXID) event, err := portal.MainIntent().GetEvent(portal.MXID, message.MXID)
if err != nil { if err != nil {
@ -421,29 +421,12 @@ func (portal *Portal) SetReply(content *gomatrix.Content, info whatsapp.MessageI
return return
} }
func (portal *Portal) FormatWhatsAppMessage(content *gomatrix.Content) { func (portal *Portal) HandleTextMessage(source *User, message whatsapp.TextMessage) {
output := html.EscapeString(content.Body)
for regex, replacement := range portal.user.waReplString {
output = regex.ReplaceAllString(output, replacement)
}
for regex, replacer := range portal.user.waReplFunc {
output = regex.ReplaceAllStringFunc(output, replacer)
}
if output != content.Body {
content.FormattedBody = output
content.Format = gomatrix.FormatHTML
for regex, replacer := range portal.user.waReplFuncText {
content.Body = regex.ReplaceAllStringFunc(content.Body, replacer)
}
}
}
func (portal *Portal) HandleTextMessage(message whatsapp.TextMessage) {
if portal.IsDuplicate(message.Info.Id) { if portal.IsDuplicate(message.Info.Id) {
return return
} }
err := portal.CreateMatrixRoom() err := portal.CreateMatrixRoom([]string{source.MXID})
if err != nil { if err != nil {
portal.log.Errorln("Failed to create portal room:", err) portal.log.Errorln("Failed to create portal room:", err)
return return
@ -459,7 +442,7 @@ func (portal *Portal) HandleTextMessage(message whatsapp.TextMessage) {
MsgType: gomatrix.MsgText, MsgType: gomatrix.MsgText,
} }
portal.FormatWhatsAppMessage(content) portal.bridge.Formatter.ParseWhatsApp(content)
portal.SetReply(content, message.Info) portal.SetReply(content, message.Info)
intent.UserTyping(portal.MXID, false, 0) intent.UserTyping(portal.MXID, false, 0)
@ -472,12 +455,12 @@ func (portal *Portal) HandleTextMessage(message whatsapp.TextMessage) {
portal.log.Debugln("Handled message", message.Info.Id, "->", resp.EventID) portal.log.Debugln("Handled message", message.Info.Id, "->", resp.EventID)
} }
func (portal *Portal) HandleMediaMessage(download func() ([]byte, error), thumbnail []byte, info whatsapp.MessageInfo, mimeType, caption string) { func (portal *Portal) HandleMediaMessage(source *User, download func() ([]byte, error), thumbnail []byte, info whatsapp.MessageInfo, mimeType, caption string) {
if portal.IsDuplicate(info.Id) { if portal.IsDuplicate(info.Id) {
return return
} }
err := portal.CreateMatrixRoom() err := portal.CreateMatrixRoom([]string{source.MXID})
if err != nil { if err != nil {
portal.log.Errorln("Failed to create portal room:", err) portal.log.Errorln("Failed to create portal room:", err)
return return
@ -559,7 +542,7 @@ func (portal *Portal) HandleMediaMessage(download func() ([]byte, error), thumbn
MsgType: gomatrix.MsgNotice, MsgType: gomatrix.MsgNotice,
} }
portal.FormatWhatsAppMessage(captionContent) portal.bridge.Formatter.ParseWhatsApp(captionContent)
_, err := intent.SendMassagedMessageEvent(portal.MXID, gomatrix.EventMessage, captionContent, ts) _, err := intent.SendMassagedMessageEvent(portal.MXID, gomatrix.EventMessage, captionContent, ts)
if err != nil { if err != nil {
@ -612,7 +595,7 @@ func (portal *Portal) downloadThumbnail(evt *gomatrix.Event) []byte {
return buf.Bytes() return buf.Bytes()
} }
func (portal *Portal) preprocessMatrixMedia(evt *gomatrix.Event, mediaType whatsapp.MediaType) *MediaUpload { func (portal *Portal) preprocessMatrixMedia(sender *User, evt *gomatrix.Event, mediaType whatsapp.MediaType) *MediaUpload {
if evt.Content.Info == nil { if evt.Content.Info == nil {
evt.Content.Info = &gomatrix.FileInfo{} evt.Content.Info = &gomatrix.FileInfo{}
} }
@ -630,7 +613,7 @@ func (portal *Portal) preprocessMatrixMedia(evt *gomatrix.Event, mediaType whats
return nil return nil
} }
url, mediaKey, fileEncSHA256, fileSHA256, fileLength, err := portal.user.Conn.Upload(bytes.NewReader(content), mediaType) url, mediaKey, fileEncSHA256, fileSHA256, fileLength, err := sender.Conn.Upload(bytes.NewReader(content), mediaType)
if err != nil { if err != nil {
portal.log.Errorfln("Failed to upload media in %s: %v", evt.ID, err) portal.log.Errorfln("Failed to upload media in %s: %v", evt.ID, err)
return nil return nil
@ -657,8 +640,8 @@ type MediaUpload struct {
Thumbnail []byte Thumbnail []byte
} }
func (portal *Portal) GetMessage(jid types.WhatsAppMessageID) *waProto.WebMessageInfo { func (portal *Portal) GetMessage(user *User, jid types.WhatsAppMessageID) *waProto.WebMessageInfo {
node, err := portal.user.Conn.LoadMessagesBefore(portal.JID, jid, 1) node, err := user.Conn.LoadMessagesBefore(portal.Key.JID, jid, 1)
if err != nil { if err != nil {
return nil return nil
} }
@ -670,7 +653,7 @@ func (portal *Portal) GetMessage(jid types.WhatsAppMessageID) *waProto.WebMessag
if !ok { if !ok {
return nil return nil
} }
node, err = portal.user.Conn.LoadMessagesAfter(portal.JID, msg.GetKey().GetId(), 1) node, err = user.Conn.LoadMessagesAfter(portal.Key.JID, msg.GetKey().GetId(), 1)
if err != nil { if err != nil {
return nil return nil
} }
@ -682,7 +665,11 @@ func (portal *Portal) GetMessage(jid types.WhatsAppMessageID) *waProto.WebMessag
return msg return msg
} }
func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) { func (portal *Portal) HandleMatrixMessage(sender *User, evt *gomatrix.Event) {
if portal.IsPrivateChat() && sender.JID != portal.Key.Receiver {
return
}
ts := uint64(evt.Timestamp / 1000) ts := uint64(evt.Timestamp / 1000)
status := waProto.WebMessageInfo_ERROR status := waProto.WebMessageInfo_ERROR
fromMe := true fromMe := true
@ -690,7 +677,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
Key: &waProto.MessageKey{ Key: &waProto.MessageKey{
FromMe: &fromMe, FromMe: &fromMe,
Id: makeMessageID(), Id: makeMessageID(),
RemoteJid: &portal.JID, RemoteJid: &portal.Key.JID,
}, },
MessageTimestamp: &ts, MessageTimestamp: &ts,
Message: &waProto.Message{}, Message: &waProto.Message{},
@ -702,12 +689,12 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
evt.Content.RemoveReplyFallback() evt.Content.RemoveReplyFallback()
msg := portal.bridge.DB.Message.GetByMXID(replyToID) msg := portal.bridge.DB.Message.GetByMXID(replyToID)
if msg != nil { if msg != nil {
origMsg := portal.GetMessage(msg.JID) origMsg := portal.GetMessage(sender, msg.JID)
if origMsg != nil { if origMsg != nil {
ctxInfo.StanzaId = &msg.JID ctxInfo.StanzaId = &msg.JID
replyMsgSender := origMsg.GetParticipant() replyMsgSender := origMsg.GetParticipant()
if origMsg.GetKey().GetFromMe() { if origMsg.GetKey().GetFromMe() {
replyMsgSender = portal.user.JID() replyMsgSender = sender.JID
} }
ctxInfo.Participant = &replyMsgSender ctxInfo.Participant = &replyMsgSender
ctxInfo.QuotedMessage = []*waProto.Message{origMsg.Message} ctxInfo.QuotedMessage = []*waProto.Message{origMsg.Message}
@ -719,7 +706,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
case gomatrix.MsgText, gomatrix.MsgEmote: case gomatrix.MsgText, gomatrix.MsgEmote:
text := evt.Content.Body text := evt.Content.Body
if evt.Content.Format == gomatrix.FormatHTML { if evt.Content.Format == gomatrix.FormatHTML {
text = portal.user.htmlParser.Parse(evt.Content.FormattedBody) text = portal.bridge.Formatter.ParseMatrix(evt.Content.FormattedBody)
} }
if evt.Content.MsgType == gomatrix.MsgEmote { if evt.Content.MsgType == gomatrix.MsgEmote {
text = "/me " + text text = "/me " + text
@ -737,7 +724,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
info.Message.Conversation = &text info.Message.Conversation = &text
} }
case gomatrix.MsgImage: case gomatrix.MsgImage:
media := portal.preprocessMatrixMedia(evt, whatsapp.MediaImage) media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaImage)
if media == nil { if media == nil {
return return
} }
@ -752,7 +739,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
FileLength: &media.FileLength, FileLength: &media.FileLength,
} }
case gomatrix.MsgVideo: case gomatrix.MsgVideo:
media := portal.preprocessMatrixMedia(evt, whatsapp.MediaVideo) media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaVideo)
if media == nil { if media == nil {
return return
} }
@ -769,7 +756,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
FileLength: &media.FileLength, FileLength: &media.FileLength,
} }
case gomatrix.MsgAudio: case gomatrix.MsgAudio:
media := portal.preprocessMatrixMedia(evt, whatsapp.MediaAudio) media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaAudio)
if media == nil { if media == nil {
return return
} }
@ -784,7 +771,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
FileLength: &media.FileLength, FileLength: &media.FileLength,
} }
case gomatrix.MsgFile: case gomatrix.MsgFile:
media := portal.preprocessMatrixMedia(evt, whatsapp.MediaDocument) media := portal.preprocessMatrixMedia(sender, evt, whatsapp.MediaDocument)
if media == nil { if media == nil {
return return
} }
@ -800,7 +787,7 @@ func (portal *Portal) HandleMatrixMessage(evt *gomatrix.Event) {
portal.log.Debugln("Unhandled Matrix event:", evt) portal.log.Debugln("Unhandled Matrix event:", evt)
return return
} }
err = portal.user.Conn.Send(info) err = sender.Conn.Send(info)
portal.MarkHandled(info.GetKey().GetId(), evt.ID) portal.MarkHandled(info.GetKey().GetId(), evt.ID)
if err != nil { if err != nil {
portal.log.Errorfln("Error handling Matrix event %s: %v", evt.ID, err) portal.log.Errorfln("Error handling Matrix event %s: %v", evt.ID, err)

View file

@ -30,105 +30,83 @@ import (
"maunium.net/go/mautrix-whatsapp/whatsapp-ext" "maunium.net/go/mautrix-whatsapp/whatsapp-ext"
) )
func (bridge *Bridge) ParsePuppetMXID(mxid types.MatrixUserID) (types.MatrixUserID, types.WhatsAppID, bool) { func (bridge *Bridge) ParsePuppetMXID(mxid types.MatrixUserID) (types.WhatsAppID, bool) {
userIDRegex, err := regexp.Compile(fmt.Sprintf("^@%s:%s$", userIDRegex, err := regexp.Compile(fmt.Sprintf("^@%s:%s$",
bridge.Config.Bridge.FormatUsername("(.+)", "([0-9]+)"), bridge.Config.Bridge.FormatUsername("([0-9]+)"),
bridge.Config.Homeserver.Domain)) bridge.Config.Homeserver.Domain))
if err != nil { if err != nil {
bridge.Log.Warnln("Failed to compile puppet user ID regex:", err) bridge.Log.Warnln("Failed to compile puppet user ID regex:", err)
return "", "", false return "", false
} }
match := userIDRegex.FindStringSubmatch(string(mxid)) match := userIDRegex.FindStringSubmatch(string(mxid))
if match == nil || len(match) != 3 { if match == nil || len(match) != 2 {
return "", "", false return "", false
} }
receiver := types.MatrixUserID(match[1])
receiver = strings.Replace(receiver, "=40", "@", 1)
colonIndex := strings.LastIndex(receiver, "=3")
receiver = receiver[:colonIndex] + ":" + receiver[colonIndex+len("=3"):]
jid := types.WhatsAppID(match[2] + whatsappExt.NewUserSuffix) jid := types.WhatsAppID(match[2] + whatsappExt.NewUserSuffix)
return receiver, jid, true return jid, true
} }
func (bridge *Bridge) GetPuppetByMXID(mxid types.MatrixUserID) *Puppet { func (bridge *Bridge) GetPuppetByMXID(mxid types.MatrixUserID) *Puppet {
receiver, jid, ok := bridge.ParsePuppetMXID(mxid) jid, ok := bridge.ParsePuppetMXID(mxid)
if !ok { if !ok {
return nil return nil
} }
user := bridge.GetUser(receiver) return bridge.GetPuppetByJID(jid)
if user == nil {
return nil
}
return user.GetPuppetByJID(jid)
} }
func (user *User) GetPuppetByMXID(mxid types.MatrixUserID) *Puppet { func (bridge *Bridge) GetPuppetByJID(jid types.WhatsAppID) *Puppet {
receiver, jid, ok := user.bridge.ParsePuppetMXID(mxid) bridge.puppetsLock.Lock()
if !ok || receiver != user.ID { defer bridge.puppetsLock.Unlock()
return nil puppet, ok := bridge.puppets[jid]
}
return user.GetPuppetByJID(jid)
}
func (user *User) GetPuppetByJID(jid types.WhatsAppID) *Puppet {
user.puppetsLock.Lock()
defer user.puppetsLock.Unlock()
puppet, ok := user.puppets[jid]
if !ok { if !ok {
dbPuppet := user.bridge.DB.Puppet.Get(jid, user.ID) dbPuppet := bridge.DB.Puppet.Get(jid)
if dbPuppet == nil { if dbPuppet == nil {
dbPuppet = user.bridge.DB.Puppet.New() dbPuppet = bridge.DB.Puppet.New()
dbPuppet.JID = jid dbPuppet.JID = jid
dbPuppet.Receiver = user.ID
dbPuppet.Insert() dbPuppet.Insert()
} }
puppet = user.NewPuppet(dbPuppet) puppet = bridge.NewPuppet(dbPuppet)
user.puppets[puppet.JID] = puppet bridge.puppets[puppet.JID] = puppet
} }
return puppet return puppet
} }
func (user *User) GetAllPuppets() []*Puppet { func (bridge *Bridge) GetAllPuppets() []*Puppet {
user.puppetsLock.Lock() bridge.puppetsLock.Lock()
defer user.puppetsLock.Unlock() defer bridge.puppetsLock.Unlock()
dbPuppets := user.bridge.DB.Puppet.GetAll(user.ID) dbPuppets := bridge.DB.Puppet.GetAll()
output := make([]*Puppet, len(dbPuppets)) output := make([]*Puppet, len(dbPuppets))
for index, dbPuppet := range dbPuppets { for index, dbPuppet := range dbPuppets {
puppet, ok := user.puppets[dbPuppet.JID] puppet, ok := bridge.puppets[dbPuppet.JID]
if !ok { if !ok {
puppet = user.NewPuppet(dbPuppet) puppet = bridge.NewPuppet(dbPuppet)
user.puppets[dbPuppet.JID] = puppet bridge.puppets[dbPuppet.JID] = puppet
} }
output[index] = puppet output[index] = puppet
} }
return output return output
} }
func (user *User) NewPuppet(dbPuppet *database.Puppet) *Puppet { func (bridge *Bridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
return &Puppet{ return &Puppet{
Puppet: dbPuppet, Puppet: dbPuppet,
user: user, bridge: bridge,
bridge: user.bridge, log: bridge.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
log: user.log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
MXID: fmt.Sprintf("@%s:%s", MXID: fmt.Sprintf("@%s:%s",
user.bridge.Config.Bridge.FormatUsername( bridge.Config.Bridge.FormatUsername(
dbPuppet.Receiver,
strings.Replace( strings.Replace(
dbPuppet.JID, dbPuppet.JID,
whatsappExt.NewUserSuffix, "", 1)), whatsappExt.NewUserSuffix, "", 1)),
user.bridge.Config.Homeserver.Domain), bridge.Config.Homeserver.Domain),
} }
} }
type Puppet struct { type Puppet struct {
*database.Puppet *database.Puppet
user *User
bridge *Bridge bridge *Bridge
log log.Logger log log.Logger
@ -143,13 +121,13 @@ func (puppet *Puppet) PhoneNumber() string {
} }
func (puppet *Puppet) Intent() *appservice.IntentAPI { func (puppet *Puppet) Intent() *appservice.IntentAPI {
return puppet.bridge.AppService.Intent(puppet.MXID) return puppet.bridge.AS.Intent(puppet.MXID)
} }
func (puppet *Puppet) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool { func (puppet *Puppet) UpdateAvatar(source *User, avatar *whatsappExt.ProfilePicInfo) bool {
if avatar == nil { if avatar == nil {
var err error var err error
avatar, err = puppet.user.Conn.GetProfilePicThumb(puppet.JID) avatar, err = source.Conn.GetProfilePicThumb(puppet.JID)
if err != nil { if err != nil {
puppet.log.Errorln(err) puppet.log.Errorln(err)
return false return false
@ -184,11 +162,11 @@ func (puppet *Puppet) UpdateAvatar(avatar *whatsappExt.ProfilePicInfo) bool {
return true return true
} }
func (puppet *Puppet) Sync(contact whatsapp.Contact) { func (puppet *Puppet) Sync(source *User, contact whatsapp.Contact) {
puppet.Intent().EnsureRegistered() puppet.Intent().EnsureRegistered()
if contact.Jid == puppet.user.JID() { if contact.Jid == source.JID {
contact.Notify = puppet.user.Conn.Info.Pushname contact.Notify = source.Conn.Info.Pushname
} }
newName := puppet.bridge.Config.Bridge.FormatDisplayname(contact) newName := puppet.bridge.Config.Bridge.FormatDisplayname(contact)
if puppet.Displayname != newName { if puppet.Displayname != newName {
@ -201,7 +179,7 @@ func (puppet *Puppet) Sync(contact whatsapp.Contact) {
} }
} }
if puppet.UpdateAvatar(nil) { if puppet.UpdateAvatar(source, nil) {
puppet.Update() puppet.Update()
} }
} }

134
user.go
View file

@ -17,14 +17,11 @@
package main package main
import ( import (
"regexp"
"strings" "strings"
"sync"
"time" "time"
"github.com/Rhymen/go-whatsapp" "github.com/Rhymen/go-whatsapp"
"github.com/skip2/go-qrcode" "github.com/skip2/go-qrcode"
"maunium.net/go/gomatrix/format"
log "maunium.net/go/maulogger" log "maunium.net/go/maulogger"
"maunium.net/go/mautrix-whatsapp/database" "maunium.net/go/mautrix-whatsapp/database"
"maunium.net/go/mautrix-whatsapp/types" "maunium.net/go/mautrix-whatsapp/types"
@ -41,31 +38,42 @@ type User struct {
Admin bool Admin bool
Whitelisted bool Whitelisted bool
jid string jid string
portalsByMXID map[types.MatrixRoomID]*Portal
portalsByJID map[types.WhatsAppID]*Portal
portalsLock sync.Mutex
puppets map[types.WhatsAppID]*Puppet
puppetsLock sync.Mutex
htmlParser *format.HTMLParser
waReplString map[*regexp.Regexp]string
waReplFunc map[*regexp.Regexp]func(string) string
waReplFuncText map[*regexp.Regexp]func(string) string
} }
func (bridge *Bridge) GetUser(userID types.MatrixUserID) *User { func (bridge *Bridge) GetUserByMXID(userID types.MatrixUserID) *User {
user, ok := bridge.users[userID] bridge.usersLock.Lock()
defer bridge.usersLock.Unlock()
user, ok := bridge.usersByMXID[userID]
if !ok { if !ok {
dbUser := bridge.DB.User.Get(userID) dbUser := bridge.DB.User.GetByMXID(userID)
if dbUser == nil { if dbUser == nil {
dbUser = bridge.DB.User.New() dbUser = bridge.DB.User.New()
dbUser.ID = userID dbUser.MXID = userID
dbUser.Insert() dbUser.Insert()
} }
user = bridge.NewUser(dbUser) user = bridge.NewUser(dbUser)
bridge.users[user.ID] = user bridge.usersByMXID[user.MXID] = user
if len(user.ManagementRoom) > 0 {
bridge.managementRooms[user.ManagementRoom] = user
}
}
return user
}
func (bridge *Bridge) GetUserByJID(userID types.WhatsAppID) *User {
bridge.usersLock.Lock()
defer bridge.usersLock.Unlock()
user, ok := bridge.usersByJID[userID]
if !ok {
dbUser := bridge.DB.User.GetByMXID(userID)
if dbUser == nil {
dbUser = bridge.DB.User.New()
dbUser.MXID = userID
dbUser.Insert()
}
user = bridge.NewUser(dbUser)
bridge.usersByJID[user.JID] = user
if len(user.ManagementRoom) > 0 { if len(user.ManagementRoom) > 0 {
bridge.managementRooms[user.ManagementRoom] = user bridge.managementRooms[user.ManagementRoom] = user
} }
@ -74,13 +82,15 @@ func (bridge *Bridge) GetUser(userID types.MatrixUserID) *User {
} }
func (bridge *Bridge) GetAllUsers() []*User { func (bridge *Bridge) GetAllUsers() []*User {
bridge.usersLock.Lock()
defer bridge.usersLock.Unlock()
dbUsers := bridge.DB.User.GetAll() dbUsers := bridge.DB.User.GetAll()
output := make([]*User, len(dbUsers)) output := make([]*User, len(dbUsers))
for index, dbUser := range dbUsers { for index, dbUser := range dbUsers {
user, ok := bridge.users[dbUser.ID] user, ok := bridge.usersByMXID[dbUser.MXID]
if !ok { if !ok {
user = bridge.NewUser(dbUser) user = bridge.NewUser(dbUser)
bridge.users[user.ID] = user bridge.usersByMXID[user.MXID] = user
if len(user.ManagementRoom) > 0 { if len(user.ManagementRoom) > 0 {
bridge.managementRooms[user.ManagementRoom] = user bridge.managementRooms[user.ManagementRoom] = user
} }
@ -94,15 +104,10 @@ func (bridge *Bridge) NewUser(dbUser *database.User) *User {
user := &User{ user := &User{
User: dbUser, User: dbUser,
bridge: bridge, bridge: bridge,
log: bridge.Log.Sub("User").Sub(string(dbUser.ID)), log: bridge.Log.Sub("User").Sub(string(dbUser.MXID)),
portalsByMXID: make(map[types.MatrixRoomID]*Portal),
portalsByJID: make(map[types.WhatsAppID]*Portal),
puppets: make(map[types.WhatsAppID]*Puppet),
} }
user.Whitelisted = user.bridge.Config.Bridge.Permissions.IsWhitelisted(user.ID) user.Whitelisted = user.bridge.Config.Bridge.Permissions.IsWhitelisted(user.MXID)
user.Admin = user.bridge.Config.Bridge.Permissions.IsAdmin(user.ID) user.Admin = user.bridge.Config.Bridge.Permissions.IsAdmin(user.MXID)
user.htmlParser = user.newHTMLParser()
user.waReplString, user.waReplFunc, user.waReplFuncText = user.newWhatsAppFormatMaps()
return user return user
} }
@ -152,7 +157,6 @@ func (user *User) RestoreSession() bool {
sess, err := user.Conn.RestoreSession(*user.Session) sess, err := user.Conn.RestoreSession(*user.Session)
if err != nil { if err != nil {
user.log.Errorln("Failed to restore session:", err) user.log.Errorln("Failed to restore session:", err)
//user.SetSession(nil)
return false return false
} }
user.SetSession(&sess) user.SetSession(&sess)
@ -162,8 +166,12 @@ func (user *User) RestoreSession() bool {
return false return false
} }
func (user *User) IsLoggedIn() bool {
return user.Conn != nil
}
func (user *User) Login(roomID types.MatrixRoomID) { func (user *User) Login(roomID types.MatrixRoomID) {
bot := user.bridge.AppService.BotClient() bot := user.bridge.AS.BotClient()
qrChan := make(chan string, 2) qrChan := make(chan string, 2)
go func() { go func() {
@ -194,38 +202,24 @@ func (user *User) Login(roomID types.MatrixRoomID) {
qrChan <- "error" qrChan <- "error"
return return
} }
user.JID = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1)
user.Session = &session user.Session = &session
user.Update() user.Update()
bot.SendNotice(roomID, "Successfully logged in. Synchronizing chats...") bot.SendNotice(roomID, "Successfully logged in. Synchronizing chats...")
go user.Sync() go user.Sync()
} }
func (user *User) JID() string {
if user.Conn == nil {
return ""
}
if len(user.jid) == 0 {
user.jid = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1)
}
return user.jid
}
func (user *User) Sync() { func (user *User) Sync() {
user.log.Debugln("Syncing...") user.log.Debugln("Syncing...")
user.Conn.Contacts() user.Conn.Contacts()
for jid, contact := range user.Conn.Store.Contacts { for jid, contact := range user.Conn.Store.Contacts {
if strings.HasSuffix(jid, whatsappExt.NewUserSuffix) { if strings.HasSuffix(jid, whatsappExt.NewUserSuffix) {
puppet := user.GetPuppetByJID(contact.Jid) puppet := user.bridge.GetPuppetByJID(contact.Jid)
puppet.Sync(contact) puppet.Sync(user, contact)
} else {
portal := user.bridge.GetPortalByJID(database.GroupPortalKey(contact.Jid))
portal.Sync(user, contact)
} }
if len(contact.Notify) == 0 && !strings.HasSuffix(jid, "@g.us") {
// No messages sent -> don't bridge
continue
}
portal := user.GetPortalByJID(contact.Jid)
portal.Sync(contact)
} }
} }
@ -237,33 +231,41 @@ func (user *User) HandleJSONParseError(err error) {
user.log.Errorln("WhatsApp JSON parse error:", err) user.log.Errorln("WhatsApp JSON parse error:", err)
} }
func (user *User) PortalKey(jid types.WhatsAppID) database.PortalKey {
return database.NewPortalKey(jid, user.JID)
}
func (user *User) GetPortalByJID(jid types.WhatsAppID) *Portal {
return user.bridge.GetPortalByJID(user.PortalKey(jid))
}
func (user *User) HandleTextMessage(message whatsapp.TextMessage) { func (user *User) HandleTextMessage(message whatsapp.TextMessage) {
portal := user.GetPortalByJID(message.Info.RemoteJid) portal := user.GetPortalByJID(message.Info.RemoteJid)
portal.HandleTextMessage(message) portal.HandleTextMessage(user, message)
} }
func (user *User) HandleImageMessage(message whatsapp.ImageMessage) { func (user *User) HandleImageMessage(message whatsapp.ImageMessage) {
portal := user.GetPortalByJID(message.Info.RemoteJid) portal := user.GetPortalByJID(message.Info.RemoteJid)
portal.HandleMediaMessage(message.Download, message.Thumbnail, message.Info, message.Type, message.Caption) portal.HandleMediaMessage(user, message.Download, message.Thumbnail, message.Info, message.Type, message.Caption)
} }
func (user *User) HandleVideoMessage(message whatsapp.VideoMessage) { func (user *User) HandleVideoMessage(message whatsapp.VideoMessage) {
portal := user.GetPortalByJID(message.Info.RemoteJid) portal := user.GetPortalByJID(message.Info.RemoteJid)
portal.HandleMediaMessage(message.Download, message.Thumbnail, message.Info, message.Type, message.Caption) portal.HandleMediaMessage(user, message.Download, message.Thumbnail, message.Info, message.Type, message.Caption)
} }
func (user *User) HandleAudioMessage(message whatsapp.AudioMessage) { func (user *User) HandleAudioMessage(message whatsapp.AudioMessage) {
portal := user.GetPortalByJID(message.Info.RemoteJid) portal := user.GetPortalByJID(message.Info.RemoteJid)
portal.HandleMediaMessage(message.Download, nil, message.Info, message.Type, "") portal.HandleMediaMessage(user, message.Download, nil, message.Info, message.Type, "")
} }
func (user *User) HandleDocumentMessage(message whatsapp.DocumentMessage) { func (user *User) HandleDocumentMessage(message whatsapp.DocumentMessage) {
portal := user.GetPortalByJID(message.Info.RemoteJid) portal := user.GetPortalByJID(message.Info.RemoteJid)
portal.HandleMediaMessage(message.Download, message.Thumbnail, message.Info, message.Type, message.Title) portal.HandleMediaMessage(user, message.Download, message.Thumbnail, message.Info, message.Type, message.Title)
} }
func (user *User) HandlePresence(info whatsappExt.Presence) { func (user *User) HandlePresence(info whatsappExt.Presence) {
puppet := user.GetPuppetByJID(info.SenderJID) puppet := user.bridge.GetPuppetByJID(info.SenderJID)
switch info.Status { switch info.Status {
case whatsappExt.PresenceUnavailable: case whatsappExt.PresenceUnavailable:
puppet.Intent().SetPresence("offline") puppet.Intent().SetPresence("offline")
@ -277,6 +279,12 @@ func (user *User) HandlePresence(info whatsappExt.Presence) {
} }
case whatsappExt.PresenceComposing: case whatsappExt.PresenceComposing:
portal := user.GetPortalByJID(info.JID) portal := user.GetPortalByJID(info.JID)
if len(puppet.typingIn) > 0 && puppet.typingAt+15 > time.Now().Unix() {
if puppet.typingIn == portal.MXID {
return
}
puppet.Intent().UserTyping(puppet.typingIn, false, 0)
}
puppet.typingIn = portal.MXID puppet.typingIn = portal.MXID
puppet.typingAt = time.Now().Unix() puppet.typingAt = time.Now().Unix()
puppet.Intent().UserTyping(portal.MXID, true, 15*1000) puppet.Intent().UserTyping(portal.MXID, true, 15*1000)
@ -290,9 +298,9 @@ func (user *User) HandleMsgInfo(info whatsappExt.MsgInfo) {
return return
} }
intent := user.GetPuppetByJID(info.SenderJID).Intent() intent := user.bridge.GetPuppetByJID(info.SenderJID).Intent()
for _, id := range info.IDs { for _, id := range info.IDs {
msg := user.bridge.DB.Message.GetByJID(user.ID, id) msg := user.bridge.DB.Message.GetByJID(portal.Key, id)
if msg == nil { if msg == nil {
continue continue
} }
@ -308,11 +316,11 @@ func (user *User) HandleCommand(cmd whatsappExt.Command) {
switch cmd.Type { switch cmd.Type {
case whatsappExt.CommandPicture: case whatsappExt.CommandPicture:
if strings.HasSuffix(cmd.JID, whatsappExt.NewUserSuffix) { if strings.HasSuffix(cmd.JID, whatsappExt.NewUserSuffix) {
puppet := user.GetPuppetByJID(cmd.JID) puppet := user.bridge.GetPuppetByJID(cmd.JID)
puppet.UpdateAvatar(cmd.ProfilePicInfo) puppet.UpdateAvatar(user, cmd.ProfilePicInfo)
} else { } else {
portal := user.GetPortalByJID(cmd.JID) portal := user.GetPortalByJID(cmd.JID)
portal.UpdateAvatar(cmd.ProfilePicInfo) portal.UpdateAvatar(user, cmd.ProfilePicInfo)
} }
} }
} }

View file

@ -463,7 +463,7 @@ func (cli *Client) SetAvatarURL(url string) (err error) {
// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
func (cli *Client) SendMessageEvent(roomID string, eventType EventType, contentJSON interface{}) (resp *RespSendEvent, err error) { func (cli *Client) SendMessageEvent(roomID string, eventType EventType, contentJSON interface{}) (resp *RespSendEvent, err error) {
txnID := txnID() txnID := txnID()
urlPath := cli.BuildURL("rooms", roomID, "send", string(eventType), txnID) urlPath := cli.BuildURL("rooms", roomID, "send", eventType.String(), txnID)
_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
return return
} }
@ -472,7 +472,7 @@ func (cli *Client) SendMessageEvent(roomID string, eventType EventType, contentJ
// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
func (cli *Client) SendMassagedMessageEvent(roomID string, eventType EventType, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) { func (cli *Client) SendMassagedMessageEvent(roomID string, eventType EventType, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) {
txnID := txnID() txnID := txnID()
urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "send", string(eventType), txnID}, map[string]string{ urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "send", eventType.String(), txnID}, map[string]string{
"ts": strconv.FormatInt(ts, 10), "ts": strconv.FormatInt(ts, 10),
}) })
_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
@ -482,7 +482,7 @@ func (cli *Client) SendMassagedMessageEvent(roomID string, eventType EventType,
// SendStateEvent sends a state event into a room. See http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-state-eventtype-statekey // SendStateEvent sends a state event into a room. See http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-state-eventtype-statekey
// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
func (cli *Client) SendStateEvent(roomID string, eventType EventType, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) { func (cli *Client) SendStateEvent(roomID string, eventType EventType, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) {
urlPath := cli.BuildURL("rooms", roomID, "state", string(eventType), stateKey) urlPath := cli.BuildURL("rooms", roomID, "state", eventType.String(), stateKey)
_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
return return
} }
@ -490,7 +490,7 @@ func (cli *Client) SendStateEvent(roomID string, eventType EventType, stateKey s
// SendStateEvent sends a state event into a room. See http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-state-eventtype-statekey // SendStateEvent sends a state event into a room. See http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-state-eventtype-statekey
// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
func (cli *Client) SendMassagedStateEvent(roomID string, eventType EventType, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) { func (cli *Client) SendMassagedStateEvent(roomID string, eventType EventType, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) {
urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "state", string(eventType), stateKey}, map[string]string{ urlPath := cli.BuildURLWithQuery([]string{"rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{
"ts": strconv.FormatInt(ts, 10), "ts": strconv.FormatInt(ts, 10),
}) })
_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
@ -500,7 +500,7 @@ func (cli *Client) SendMassagedStateEvent(roomID string, eventType EventType, st
// SendText sends an m.room.message event into the given room with a msgtype of m.text // SendText sends an m.room.message event into the given room with a msgtype of m.text
// See http://matrix.org/docs/spec/client_server/r0.2.0.html#m-text // See http://matrix.org/docs/spec/client_server/r0.2.0.html#m-text
func (cli *Client) SendText(roomID, text string) (*RespSendEvent, error) { func (cli *Client) SendText(roomID, text string) (*RespSendEvent, error) {
return cli.SendMessageEvent(roomID, "m.room.message", Content{ return cli.SendMessageEvent(roomID, EventMessage, Content{
MsgType: MsgText, MsgType: MsgText,
Body: text, Body: text,
}) })
@ -509,7 +509,7 @@ func (cli *Client) SendText(roomID, text string) (*RespSendEvent, error) {
// SendImage sends an m.room.message event into the given room with a msgtype of m.image // SendImage sends an m.room.message event into the given room with a msgtype of m.image
// See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-image // See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-image
func (cli *Client) SendImage(roomID, body, url string) (*RespSendEvent, error) { func (cli *Client) SendImage(roomID, body, url string) (*RespSendEvent, error) {
return cli.SendMessageEvent(roomID, "m.room.message", Content{ return cli.SendMessageEvent(roomID, EventMessage, Content{
MsgType: MsgImage, MsgType: MsgImage,
Body: body, Body: body,
URL: url, URL: url,
@ -519,7 +519,7 @@ func (cli *Client) SendImage(roomID, body, url string) (*RespSendEvent, error) {
// SendVideo sends an m.room.message event into the given room with a msgtype of m.video // SendVideo sends an m.room.message event into the given room with a msgtype of m.video
// See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-video // See https://matrix.org/docs/spec/client_server/r0.2.0.html#m-video
func (cli *Client) SendVideo(roomID, body, url string) (*RespSendEvent, error) { func (cli *Client) SendVideo(roomID, body, url string) (*RespSendEvent, error) {
return cli.SendMessageEvent(roomID, "m.room.message", Content{ return cli.SendMessageEvent(roomID, EventMessage, Content{
MsgType: MsgVideo, MsgType: MsgVideo,
Body: body, Body: body,
URL: url, URL: url,
@ -529,7 +529,7 @@ func (cli *Client) SendVideo(roomID, body, url string) (*RespSendEvent, error) {
// SendNotice sends an m.room.message event into the given room with a msgtype of m.notice // SendNotice sends an m.room.message event into the given room with a msgtype of m.notice
// See http://matrix.org/docs/spec/client_server/r0.2.0.html#m-notice // See http://matrix.org/docs/spec/client_server/r0.2.0.html#m-notice
func (cli *Client) SendNotice(roomID, text string) (*RespSendEvent, error) { func (cli *Client) SendNotice(roomID, text string) (*RespSendEvent, error) {
return cli.SendMessageEvent(roomID, "m.room.message", Content{ return cli.SendMessageEvent(roomID, EventMessage, Content{
MsgType: MsgNotice, MsgType: MsgNotice,
Body: text, Body: text,
}) })
@ -622,7 +622,7 @@ func (cli *Client) SetPresence(status string) (err error) {
// the HTTP response body, or return an error. // the HTTP response body, or return an error.
// See http://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-rooms-roomid-state-eventtype-statekey // See http://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-rooms-roomid-state-eventtype-statekey
func (cli *Client) StateEvent(roomID string, eventType EventType, stateKey string, outContent interface{}) (err error) { func (cli *Client) StateEvent(roomID string, eventType EventType, stateKey string, outContent interface{}) (err error) {
u := cli.BuildURL("rooms", roomID, "state", string(eventType), stateKey) u := cli.BuildURL("rooms", roomID, "state", eventType.String(), stateKey)
_, err = cli.MakeRequest("GET", u, nil, outContent) _, err = cli.MakeRequest("GET", u, nil, outContent)
return return
} }

View file

@ -5,28 +5,44 @@ import (
"sync" "sync"
) )
type EventType string type EventType struct {
Type string
IsState bool
}
func (et *EventType) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &et.Type)
}
func (et *EventType) MarshalJSON() ([]byte, error) {
return json.Marshal(&et.Type)
}
func (et *EventType) String() string {
return et.Type
}
type MessageType string type MessageType string
// State events // State events
const ( var (
StateAliases EventType = "m.room.aliases" StateAliases = EventType{"m.room.aliases", true}
StateCanonicalAlias = "m.room.canonical_alias" StateCanonicalAlias = EventType{"m.room.canonical_alias", true}
StateCreate = "m.room.create" StateCreate = EventType{"m.room.create", true}
StateJoinRules = "m.room.join_rules" StateJoinRules = EventType{"m.room.join_rules", true}
StateMember = "m.room.member" StateMember = EventType{"m.room.member", true}
StatePowerLevels = "m.room.power_levels" StatePowerLevels = EventType{"m.room.power_levels", true}
StateRoomName = "m.room.name" StateRoomName = EventType{"m.room.name", true}
StateTopic = "m.room.topic" StateTopic = EventType{"m.room.topic", true}
StateRoomAvatar = "m.room.avatar" StateRoomAvatar = EventType{"m.room.avatar", true}
StatePinnedEvents = "m.room.pinned_events" StatePinnedEvents = EventType{"m.room.pinned_events", true}
) )
// Message events // Message events
const ( var (
EventRedaction EventType = "m.room.redaction" EventRedaction = EventType{"m.room.redaction", false}
EventMessage = "m.room.message" EventMessage = EventType{"m.room.message", false}
EventSticker = "m.sticker" EventSticker = EventType{"m.sticker", false}
) )
// Msgtypes // Msgtypes
@ -258,12 +274,12 @@ func (pl *PowerLevels) EnsureUserLevel(userID string, level int) bool {
return false return false
} }
func (pl *PowerLevels) GetEventLevel(eventType EventType, isState bool) int { func (pl *PowerLevels) GetEventLevel(eventType EventType) int {
pl.eventsLock.RLock() pl.eventsLock.RLock()
defer pl.eventsLock.RUnlock() defer pl.eventsLock.RUnlock()
level, ok := pl.Events[eventType] level, ok := pl.Events[eventType]
if !ok { if !ok {
if isState { if eventType.IsState {
return pl.StateDefault() return pl.StateDefault()
} }
return pl.EventsDefault return pl.EventsDefault
@ -271,20 +287,20 @@ func (pl *PowerLevels) GetEventLevel(eventType EventType, isState bool) int {
return level return level
} }
func (pl *PowerLevels) SetEventLevel(eventType EventType, isState bool, level int) { func (pl *PowerLevels) SetEventLevel(eventType EventType, level int) {
pl.eventsLock.Lock() pl.eventsLock.Lock()
defer pl.eventsLock.Unlock() defer pl.eventsLock.Unlock()
if (isState && level == pl.StateDefault()) || (!isState && level == pl.EventsDefault) { if (eventType.IsState && level == pl.StateDefault()) || (!eventType.IsState && level == pl.EventsDefault) {
delete(pl.Events, eventType) delete(pl.Events, eventType)
} else { } else {
pl.Events[eventType] = level pl.Events[eventType] = level
} }
} }
func (pl *PowerLevels) EnsureEventLevel(eventType EventType, isState bool, level int) bool { func (pl *PowerLevels) EnsureEventLevel(eventType EventType, level int) bool {
existingLevel := pl.GetEventLevel(eventType, isState) existingLevel := pl.GetEventLevel(eventType)
if existingLevel != level { if existingLevel != level {
pl.SetEventLevel(eventType, isState, level) pl.SetEventLevel(eventType, level)
return true return true
} }
return false return false

View file

@ -2,17 +2,19 @@ package appservice
import ( import (
"fmt" "fmt"
"html/template"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"maunium.net/go/maulogger"
"strings"
"net/http"
"errors" "errors"
"maunium.net/go/gomatrix" "maunium.net/go/gomatrix"
"maunium.net/go/maulogger"
"net/http"
"regexp" "regexp"
"strings"
) )
// EventChannelSize is the size for the Events channel in Appservice instances. // EventChannelSize is the size for the Events channel in Appservice instances.
@ -263,15 +265,24 @@ func CreateLogConfig() LogConfig {
} }
} }
type FileFormatData struct {
Date string
Index int
}
// GetFileFormat returns a mauLogger-compatible logger file format based on the data in the struct. // GetFileFormat returns a mauLogger-compatible logger file format based on the data in the struct.
func (lc LogConfig) GetFileFormat() maulogger.LoggerFileFormat { func (lc LogConfig) GetFileFormat() maulogger.LoggerFileFormat {
path := lc.FileNameFormat os.MkdirAll(lc.Directory, 0700)
if len(lc.Directory) > 0 { path := filepath.Join(lc.Directory, lc.FileNameFormat)
path = lc.Directory + "/" + path tpl, _ := template.New("fileformat").Parse(path)
}
return func(now string, i int) string { return func(now string, i int) string {
return fmt.Sprintf(path, now, i) var buf strings.Builder
tpl.Execute(&buf, FileFormatData{
Date: now,
Index: i,
})
return buf.String()
} }
} }

View file

@ -201,19 +201,19 @@ func (intent *IntentAPI) RedactEvent(roomID, eventID string, req *gomatrix.ReqRe
} }
func (intent *IntentAPI) SetRoomName(roomID, roomName string) (*gomatrix.RespSendEvent, error) { func (intent *IntentAPI) SetRoomName(roomID, roomName string) (*gomatrix.RespSendEvent, error) {
return intent.SendStateEvent(roomID, "m.room.name", "", map[string]interface{}{ return intent.SendStateEvent(roomID, gomatrix.StateRoomName, "", map[string]interface{}{
"name": roomName, "name": roomName,
}) })
} }
func (intent *IntentAPI) SetRoomAvatar(roomID, avatarURL string) (*gomatrix.RespSendEvent, error) { func (intent *IntentAPI) SetRoomAvatar(roomID, avatarURL string) (*gomatrix.RespSendEvent, error) {
return intent.SendStateEvent(roomID, "m.room.avatar", "", map[string]interface{}{ return intent.SendStateEvent(roomID, gomatrix.StateRoomAvatar, "", map[string]interface{}{
"url": avatarURL, "url": avatarURL,
}) })
} }
func (intent *IntentAPI) SetRoomTopic(roomID, topic string) (*gomatrix.RespSendEvent, error) { func (intent *IntentAPI) SetRoomTopic(roomID, topic string) (*gomatrix.RespSendEvent, error) {
return intent.SendStateEvent(roomID, "m.room.topic", "", map[string]interface{}{ return intent.SendStateEvent(roomID, gomatrix.StateTopic, "", map[string]interface{}{
"topic": topic, "topic": topic,
}) })
} }

View file

@ -15,13 +15,15 @@ type StateStore interface {
SetTyping(roomID, userID string, timeout int64) SetTyping(roomID, userID string, timeout int64)
IsInRoom(roomID, userID string) bool IsInRoom(roomID, userID string) bool
IsInvited(roomID, userID string) bool
IsMembership(roomID, userID string, allowedMemberships ...string) bool
SetMembership(roomID, userID, membership string) SetMembership(roomID, userID, membership string)
SetPowerLevels(roomID string, levels *gomatrix.PowerLevels) SetPowerLevels(roomID string, levels *gomatrix.PowerLevels)
GetPowerLevels(roomID string) *gomatrix.PowerLevels GetPowerLevels(roomID string) *gomatrix.PowerLevels
GetPowerLevel(roomID, userID string) int GetPowerLevel(roomID, userID string) int
GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType, isState bool) int GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType) int
HasPowerLevel(roomID, userID string, eventType gomatrix.EventType, isState bool) bool HasPowerLevel(roomID, userID string, eventType gomatrix.EventType) bool
} }
func (as *AppService) UpdateState(evt *gomatrix.Event) { func (as *AppService) UpdateState(evt *gomatrix.Event) {
@ -126,7 +128,21 @@ func (store *BasicStateStore) GetMembership(roomID, userID string) string {
} }
func (store *BasicStateStore) IsInRoom(roomID, userID string) bool { func (store *BasicStateStore) IsInRoom(roomID, userID string) bool {
return store.GetMembership(roomID, userID) == "join" return store.IsMembership(roomID, userID, "join")
}
func (store *BasicStateStore) IsInvited(roomID, userID string) bool {
return store.IsMembership(roomID, userID, "join", "invite")
}
func (store *BasicStateStore) IsMembership(roomID, userID string, allowedMemberships ...string) bool {
membership := store.GetMembership(roomID, userID)
for _, allowedMembership := range allowedMemberships {
if allowedMembership == membership {
return true
}
}
return false
} }
func (store *BasicStateStore) SetMembership(roomID, userID, membership string) { func (store *BasicStateStore) SetMembership(roomID, userID, membership string) {
@ -160,19 +176,10 @@ func (store *BasicStateStore) GetPowerLevel(roomID, userID string) int {
return store.GetPowerLevels(roomID).GetUserLevel(userID) return store.GetPowerLevels(roomID).GetUserLevel(userID)
} }
func (store *BasicStateStore) GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType, isState bool) int { func (store *BasicStateStore) GetPowerLevelRequirement(roomID string, eventType gomatrix.EventType) int {
levels := store.GetPowerLevels(roomID) return store.GetPowerLevels(roomID).GetEventLevel(eventType)
switch eventType {
case "kick":
return levels.Kick()
case "invite":
return levels.Invite()
case "redact":
return levels.Redact()
}
return levels.GetEventLevel(eventType, isState)
} }
func (store *BasicStateStore) HasPowerLevel(roomID, userID string, eventType gomatrix.EventType, isState bool) bool { func (store *BasicStateStore) HasPowerLevel(roomID, userID string, eventType gomatrix.EventType) bool {
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType, isState) return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
} }