Move a bunch of stuff to mautrix-go

See d578d1a610

Database upgrades from before v0.4.0 were squashed, users must update
to at least v0.4.0 before updating beyond this commit.
This commit is contained in:
Tulir Asokan 2022-05-22 01:06:30 +03:00
parent 42a4839a4e
commit a948ea0146
83 changed files with 627 additions and 2838 deletions

View file

@ -114,18 +114,18 @@ func (pong *BridgeState) shouldDeduplicate(newPong *BridgeState) bool {
return pong.Timestamp+int64(pong.TTL/5) > time.Now().Unix() return pong.Timestamp+int64(pong.TTL/5) > time.Now().Unix()
} }
func (bridge *Bridge) sendBridgeState(ctx context.Context, state *BridgeState) error { func (br *WABridge) sendBridgeState(ctx context.Context, state *BridgeState) error {
var body bytes.Buffer var body bytes.Buffer
if err := json.NewEncoder(&body).Encode(&state); err != nil { if err := json.NewEncoder(&body).Encode(&state); err != nil {
return fmt.Errorf("failed to encode bridge state JSON: %w", err) return fmt.Errorf("failed to encode bridge state JSON: %w", err)
} }
req, err := http.NewRequestWithContext(ctx, http.MethodPost, bridge.Config.Homeserver.StatusEndpoint, &body) req, err := http.NewRequestWithContext(ctx, http.MethodPost, br.Config.Homeserver.StatusEndpoint, &body)
if err != nil { if err != nil {
return fmt.Errorf("failed to prepare request: %w", err) return fmt.Errorf("failed to prepare request: %w", err)
} }
req.Header.Set("Authorization", "Bearer "+bridge.Config.AppService.ASToken) req.Header.Set("Authorization", "Bearer "+br.Config.AppService.ASToken)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
@ -143,17 +143,17 @@ func (bridge *Bridge) sendBridgeState(ctx context.Context, state *BridgeState) e
return nil return nil
} }
func (bridge *Bridge) sendGlobalBridgeState(state BridgeState) { func (br *WABridge) sendGlobalBridgeState(state BridgeState) {
if len(bridge.Config.Homeserver.StatusEndpoint) == 0 { if len(br.Config.Homeserver.StatusEndpoint) == 0 {
return return
} }
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
if err := bridge.sendBridgeState(ctx, &state); err != nil { if err := br.sendBridgeState(ctx, &state); err != nil {
bridge.Log.Warnln("Failed to update global bridge state:", err) br.Log.Warnln("Failed to update global bridge state:", err)
} else { } else {
bridge.Log.Debugfln("Sent new global bridge state %+v", state) br.Log.Debugfln("Sent new global bridge state %+v", state)
} }
} }

View file

@ -39,6 +39,7 @@ import (
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format" "maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
@ -47,12 +48,12 @@ import (
) )
type CommandHandler struct { type CommandHandler struct {
bridge *Bridge bridge *WABridge
log maulogger.Logger log maulogger.Logger
} }
// NewCommandHandler creates a CommandHandler // NewCommandHandler creates a CommandHandler
func NewCommandHandler(bridge *Bridge) *CommandHandler { func NewCommandHandler(bridge *WABridge) *CommandHandler {
return &CommandHandler{ return &CommandHandler{
bridge: bridge, bridge: bridge,
log: bridge.Log.Sub("Command handler"), log: bridge.Log.Sub("Command handler"),
@ -62,7 +63,7 @@ func NewCommandHandler(bridge *Bridge) *CommandHandler {
// CommandEvent stores all data which might be used to handle commands // CommandEvent stores all data which might be used to handle commands
type CommandEvent struct { type CommandEvent struct {
Bot *appservice.IntentAPI Bot *appservice.IntentAPI
Bridge *Bridge Bridge *WABridge
Portal *Portal Portal *Portal
Handler *CommandHandler Handler *CommandHandler
RoomID id.RoomID RoomID id.RoomID
@ -251,13 +252,7 @@ func (handler *CommandHandler) CommandDevTest(_ *CommandEvent) {
const cmdVersionHelp = `version - View the bridge version` const cmdVersionHelp = `version - View the bridge version`
func (handler *CommandHandler) CommandVersion(ce *CommandEvent) { func (handler *CommandHandler) CommandVersion(ce *CommandEvent) {
linkifiedVersion := fmt.Sprintf("v%s", Version) ce.Reply(fmt.Sprintf("[%s](%s) %s (%s)", ce.Bridge.Name, ce.Bridge.URL, ce.Bridge.LinkifiedVersion, BuildTime))
if Tag == Version {
linkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", Version, URL, Tag)
} else if len(Commit) > 8 {
linkifiedVersion = strings.Replace(linkifiedVersion, Commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", Commit[:8], URL, Commit), 1)
}
ce.Reply(fmt.Sprintf("[%s](%s) %s (%s)", Name, URL, linkifiedVersion, BuildTime))
} }
const cmdInviteLinkHelp = `invite-link [--reset] - Get an invite link to the current group chat, optionally regenerating the link and revoking the old link.` const cmdInviteLinkHelp = `invite-link [--reset] - Get an invite link to the current group chat, optionally regenerating the link and revoking the old link.`
@ -331,7 +326,7 @@ func (handler *CommandHandler) CommandJoin(ce *CommandEvent) {
ce.Reply("Successfully joined group `%s`, the portal should be created momentarily", jid) ce.Reply("Successfully joined group `%s`, the portal should be created momentarily", jid)
} }
func tryDecryptEvent(crypto Crypto, evt *event.Event) (json.RawMessage, error) { func tryDecryptEvent(crypto bridge.Crypto, evt *event.Event) (json.RawMessage, error) {
var data json.RawMessage var data json.RawMessage
if evt.Type != event.EventEncrypted { if evt.Type != event.EventEncrypted {
data = evt.Content.VeryRaw data = evt.Content.VeryRaw
@ -903,7 +898,7 @@ func matchesQuery(str string, query string) bool {
return strings.Contains(strings.ToLower(str), query) return strings.Contains(strings.ToLower(str), query)
} }
func formatContacts(bridge *Bridge, input map[types.JID]types.ContactInfo, query string) (result []string) { func formatContacts(bridge *WABridge, input map[types.JID]types.ContactInfo, query string) (result []string) {
hasQuery := len(query) > 0 hasQuery := len(query) > 0
for jid, contact := range input { for jid, contact := range input {
if len(contact.FullName) == 0 { if len(contact.FullName) == 0 {

View file

@ -24,6 +24,7 @@ import (
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
"maunium.net/go/mautrix/bridge/bridgeconfig"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
@ -118,23 +119,23 @@ type BridgeConfig struct {
AdditionalHelp string `yaml:"additional_help"` AdditionalHelp string `yaml:"additional_help"`
} `yaml:"management_room_text"` } `yaml:"management_room_text"`
Encryption struct { Encryption bridgeconfig.EncryptionConfig `yaml:"encryption"`
Allow bool `yaml:"allow"`
Default bool `yaml:"default"`
KeySharing struct { Provisioning struct {
Allow bool `yaml:"allow"` Prefix string `yaml:"prefix"`
RequireCrossSigning bool `yaml:"require_cross_signing"` SharedSecret string `yaml:"shared_secret"`
RequireVerification bool `yaml:"require_verification"` } `yaml:"provisioning"`
} `yaml:"key_sharing"`
} `yaml:"encryption"`
Permissions PermissionConfig `yaml:"permissions"` Permissions PermissionConfig `yaml:"permissions"`
Relay RelaybotConfig `yaml:"relay"` Relay RelaybotConfig `yaml:"relay"`
usernameTemplate *template.Template `yaml:"-"` ParsedUsernameTemplate *template.Template `yaml:"-"`
displaynameTemplate *template.Template `yaml:"-"` displaynameTemplate *template.Template `yaml:"-"`
}
func (bc BridgeConfig) GetEncryptionConfig() bridgeconfig.EncryptionConfig {
return bc.Encryption
} }
type umBridgeConfig BridgeConfig type umBridgeConfig BridgeConfig
@ -145,7 +146,7 @@ func (bc *BridgeConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
return err return err
} }
bc.usernameTemplate, err = template.New("username").Parse(bc.UsernameTemplate) bc.ParsedUsernameTemplate, err = template.New("username").Parse(bc.UsernameTemplate)
if err != nil { if err != nil {
return err return err
} else if !strings.Contains(bc.FormatUsername("1234567890"), "1234567890") { } else if !strings.Contains(bc.FormatUsername("1234567890"), "1234567890") {
@ -206,7 +207,7 @@ func (bc BridgeConfig) FormatDisplayname(jid types.JID, contact types.ContactInf
func (bc BridgeConfig) FormatUsername(username string) string { func (bc BridgeConfig) FormatUsername(username string) string {
var buf strings.Builder var buf strings.Builder
_ = bc.usernameTemplate.Execute(&buf, username) _ = bc.ParsedUsernameTemplate.Execute(&buf, username)
return buf.String() return buf.String()
} }

View file

@ -17,52 +17,12 @@
package config package config
import ( import (
"fmt" "maunium.net/go/mautrix/bridge/bridgeconfig"
"gopkg.in/yaml.v3"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
var ExampleConfig string
type Config struct { type Config struct {
Homeserver struct { *bridgeconfig.BaseConfig `yaml:",inline"`
Address string `yaml:"address"`
Domain string `yaml:"domain"`
Asmux bool `yaml:"asmux"`
StatusEndpoint string `yaml:"status_endpoint"`
MessageSendCheckpointEndpoint string `yaml:"message_send_checkpoint_endpoint"`
AsyncMedia bool `yaml:"async_media"`
} `yaml:"homeserver"`
AppService struct {
Address string `yaml:"address"`
Hostname string `yaml:"hostname"`
Port uint16 `yaml:"port"`
Database DatabaseConfig `yaml:"database"`
Provisioning struct {
Prefix string `yaml:"prefix"`
SharedSecret string `yaml:"shared_secret"`
} `yaml:"provisioning"`
ID string `yaml:"id"`
Bot struct {
Username string `yaml:"username"`
Displayname string `yaml:"displayname"`
Avatar string `yaml:"avatar"`
ParsedAvatar id.ContentURI `yaml:"-"`
} `yaml:"bot"`
EphemeralEvents bool `yaml:"ephemeral_events"`
ASToken string `yaml:"as_token"`
HSToken string `yaml:"hs_token"`
} `yaml:"appservice"`
SegmentKey string `yaml:"segment_key"` SegmentKey string `yaml:"segment_key"`
@ -77,8 +37,6 @@ type Config struct {
} `yaml:"whatsapp"` } `yaml:"whatsapp"`
Bridge BridgeConfig `yaml:"bridge"` Bridge BridgeConfig `yaml:"bridge"`
Logging appservice.LogConfig `yaml:"logging"`
} }
func (config *Config) CanAutoDoublePuppet(userID id.UserID) bool { func (config *Config) CanAutoDoublePuppet(userID id.UserID) bool {
@ -98,44 +56,3 @@ func (config *Config) CanDoublePuppetBackfill(userID id.UserID) bool {
} }
return true return true
} }
func Load(data []byte, upgraded bool) (*Config, error) {
var config = &Config{}
if !upgraded {
// Fallback: if config upgrading failed, load example config for base values
err := yaml.Unmarshal([]byte(ExampleConfig), config)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal example config: %w", err)
}
}
err := yaml.Unmarshal(data, config)
if err != nil {
return nil, err
}
return config, err
}
func (config *Config) MakeAppService() (*appservice.AppService, error) {
as := appservice.Create()
as.HomeserverDomain = config.Homeserver.Domain
as.HomeserverURL = config.Homeserver.Address
as.Host.Hostname = config.AppService.Hostname
as.Host.Port = config.AppService.Port
as.MessageSendCheckpointEndpoint = config.Homeserver.MessageSendCheckpointEndpoint
as.DefaultHTTPRetries = 4
var err error
as.Registration, err = config.GetRegistration()
return as, err
}
type DatabaseConfig struct {
Type string `yaml:"type"`
URI string `yaml:"uri"`
MaxOpenConns int `yaml:"max_open_conns"`
MaxIdleConns int `yaml:"max_idle_conns"`
ConnMaxIdleTime string `yaml:"conn_max_idle_time"`
ConnMaxLifetime string `yaml:"conn_max_lifetime"`
}

View file

@ -1,82 +0,0 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"fmt"
"regexp"
"strings"
"maunium.net/go/mautrix/appservice"
)
func (config *Config) NewRegistration() (*appservice.Registration, error) {
registration := appservice.CreateRegistration()
err := config.copyToRegistration(registration)
if err != nil {
return nil, err
}
config.AppService.ASToken = registration.AppToken
config.AppService.HSToken = registration.ServerToken
// Workaround for https://github.com/matrix-org/synapse/pull/5758
registration.SenderLocalpart = appservice.RandomString(32)
botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$",
regexp.QuoteMeta(config.AppService.Bot.Username),
regexp.QuoteMeta(config.Homeserver.Domain)))
registration.Namespaces.RegisterUserIDs(botRegex, true)
return registration, nil
}
func (config *Config) GetRegistration() (*appservice.Registration, error) {
registration := appservice.CreateRegistration()
err := config.copyToRegistration(registration)
if err != nil {
return nil, err
}
registration.AppToken = config.AppService.ASToken
registration.ServerToken = config.AppService.HSToken
return registration, nil
}
func (config *Config) copyToRegistration(registration *appservice.Registration) error {
registration.ID = config.AppService.ID
registration.URL = config.AppService.Address
falseVal := false
registration.RateLimited = &falseVal
registration.SenderLocalpart = config.AppService.Bot.Username
registration.EphemeralEvents = config.AppService.EphemeralEvents
usernamePlaceholder := appservice.RandomString(16)
usernameTemplate := fmt.Sprintf("@%s:%s",
config.Bridge.FormatUsername(usernamePlaceholder),
config.Homeserver.Domain)
usernameTemplate = regexp.QuoteMeta(usernameTemplate)
usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, "[0-9]+", 1)
usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate)
userIDRegex, err := regexp.Compile(usernameTemplate)
if err != nil {
return err
}
registration.Namespaces.RegisterUserIDs(userIDRegex, true)
return nil
}

View file

@ -20,50 +20,12 @@ import (
"strings" "strings"
"maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge/bridgeconfig"
up "maunium.net/go/mautrix/util/configupgrade" up "maunium.net/go/mautrix/util/configupgrade"
) )
type waUpgrader struct{} func DoUpgrade(helper *up.Helper) {
bridgeconfig.Upgrader.DoUpgrade(helper)
func (wau waUpgrader) GetBase() string {
return ExampleConfig
}
func (wau waUpgrader) DoUpgrade(helper *up.Helper) {
helper.Copy(up.Str, "homeserver", "address")
helper.Copy(up.Str, "homeserver", "domain")
helper.Copy(up.Bool, "homeserver", "asmux")
helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint")
helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint")
helper.Copy(up.Bool, "homeserver", "async_media")
helper.Copy(up.Str, "appservice", "address")
helper.Copy(up.Str, "appservice", "hostname")
helper.Copy(up.Int, "appservice", "port")
helper.Copy(up.Str, "appservice", "database", "type")
helper.Copy(up.Str, "appservice", "database", "uri")
helper.Copy(up.Int, "appservice", "database", "max_open_conns")
helper.Copy(up.Int, "appservice", "database", "max_idle_conns")
helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_idle_time")
helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_lifetime")
if prefix, ok := helper.Get(up.Str, "appservice", "provisioning", "prefix"); ok && strings.HasSuffix(prefix, "/v1") {
helper.Set(up.Str, strings.TrimSuffix(prefix, "/v1"), "appservice", "provisioning", "prefix")
} else {
helper.Copy(up.Str, "appservice", "provisioning", "prefix")
}
if secret, ok := helper.Get(up.Str, "appservice", "provisioning", "shared_secret"); !ok || secret == "generate" {
sharedSecret := appservice.RandomString(64)
helper.Set(up.Str, sharedSecret, "appservice", "provisioning", "shared_secret")
} else {
helper.Copy(up.Str, "appservice", "provisioning", "shared_secret")
}
helper.Copy(up.Str, "appservice", "id")
helper.Copy(up.Str, "appservice", "bot", "username")
helper.Copy(up.Str, "appservice", "bot", "displayname")
helper.Copy(up.Str, "appservice", "bot", "avatar")
helper.Copy(up.Bool, "appservice", "ephemeral_events")
helper.Copy(up.Str, "appservice", "as_token")
helper.Copy(up.Str, "appservice", "hs_token")
helper.Copy(up.Str|up.Null, "segment_key") helper.Copy(up.Str|up.Null, "segment_key")
@ -134,46 +96,41 @@ func (wau waUpgrader) DoUpgrade(helper *up.Helper) {
helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "allow") helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "allow")
helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "require_cross_signing") helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "require_cross_signing")
helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "require_verification") helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "require_verification")
if prefix, ok := helper.Get(up.Str, "appservice", "provisioning", "prefix"); ok {
helper.Set(up.Str, strings.TrimSuffix(prefix, "/v1"), "bridge", "provisioning", "prefix")
} else {
helper.Copy(up.Str, "bridge", "provisioning", "prefix")
}
if secret, ok := helper.Get(up.Str, "appservice", "provisioning", "shared_secret"); ok && secret != "generate" {
helper.Set(up.Str, secret, "bridge", "provisioning", "shared_secret")
} else if secret, ok = helper.Get(up.Str, "bridge", "provisioning", "shared_secret"); !ok || secret == "generate" {
sharedSecret := appservice.RandomString(64)
helper.Set(up.Str, sharedSecret, "bridge", "provisioning", "shared_secret")
} else {
helper.Copy(up.Str, "bridge", "provisioning", "shared_secret")
}
helper.Copy(up.Map, "bridge", "permissions") helper.Copy(up.Map, "bridge", "permissions")
helper.Copy(up.Bool, "bridge", "relay", "enabled") helper.Copy(up.Bool, "bridge", "relay", "enabled")
helper.Copy(up.Bool, "bridge", "relay", "admin_only") helper.Copy(up.Bool, "bridge", "relay", "admin_only")
helper.Copy(up.Map, "bridge", "relay", "message_formats") helper.Copy(up.Map, "bridge", "relay", "message_formats")
helper.Copy(up.Str, "logging", "directory")
helper.Copy(up.Str|up.Null, "logging", "file_name_format")
helper.Copy(up.Str|up.Timestamp, "logging", "file_date_format")
helper.Copy(up.Int, "logging", "file_mode")
helper.Copy(up.Str|up.Timestamp, "logging", "timestamp_format")
helper.Copy(up.Str, "logging", "print_level")
} }
func (wau waUpgrader) SpacedBlocks() [][]string { var SpacedBlocks = [][]string{
return [][]string{ {"homeserver", "asmux"},
{"homeserver", "asmux"}, {"appservice"},
{"appservice"}, {"appservice", "hostname"},
{"appservice", "hostname"}, {"appservice", "database"},
{"appservice", "database"}, {"appservice", "id"},
{"appservice", "provisioning"}, {"appservice", "as_token"},
{"appservice", "id"}, {"segment_key"},
{"appservice", "as_token"}, {"metrics"},
{"segment_key"}, {"whatsapp"},
{"metrics"}, {"bridge"},
{"whatsapp"}, {"bridge", "command_prefix"},
{"bridge"}, {"bridge", "management_room_text"},
{"bridge", "command_prefix"}, {"bridge", "encryption"},
{"bridge", "management_room_text"}, {"bridge", "provisioning"},
{"bridge", "encryption"}, {"bridge", "permissions"},
{"bridge", "permissions"}, {"bridge", "relay"},
{"bridge", "relay"}, {"logging"},
{"logging"},
}
}
func Mutate(path string, mutate func(helper *up.Helper)) error {
_, _, err := up.Do(path, true, waUpgrader{}, up.SimpleUpgrader(mutate))
return err
}
func Upgrade(path string, save bool) ([]byte, bool, error) {
return up.Do(path, save, waUpgrader{})
} }

327
crypto.go
View file

@ -1,327 +0,0 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2020 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//go:build cgo && !nocrypto
package main
import (
"fmt"
"runtime/debug"
"time"
"github.com/lib/pq"
"maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix-whatsapp/database"
)
var NoSessionFound = crypto.NoSessionFound
var levelTrace = maulogger.Level{
Name: "TRACE",
Severity: -10,
Color: -1,
}
type CryptoHelper struct {
bridge *Bridge
client *mautrix.Client
mach *crypto.OlmMachine
store *database.SQLCryptoStore
log maulogger.Logger
baseLog maulogger.Logger
}
func init() {
crypto.PostgresArrayWrapper = pq.Array
}
func NewCryptoHelper(bridge *Bridge) Crypto {
if !bridge.Config.Bridge.Encryption.Allow {
bridge.Log.Debugln("Bridge built with end-to-bridge encryption, but disabled in config")
return nil
}
baseLog := bridge.Log.Sub("Crypto")
return &CryptoHelper{
bridge: bridge,
log: baseLog.Sub("Helper"),
baseLog: baseLog,
}
}
func (helper *CryptoHelper) Init() error {
helper.log.Debugln("Initializing end-to-bridge encryption...")
helper.store = database.NewSQLCryptoStore(helper.bridge.DB, helper.bridge.AS.BotMXID(),
fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain))
var err error
helper.client, err = helper.loginBot()
if err != nil {
return err
}
helper.log.Debugln("Logged in as bridge bot with device ID", helper.client.DeviceID)
logger := &cryptoLogger{helper.baseLog}
stateStore := &cryptoStateStore{helper.bridge}
helper.mach = crypto.NewOlmMachine(helper.client, logger, helper.store, stateStore)
helper.mach.AllowKeyShare = helper.allowKeyShare
helper.client.Syncer = &cryptoSyncer{helper.mach}
helper.client.Store = &cryptoClientStore{helper.store}
return helper.mach.Load()
}
func (helper *CryptoHelper) allowKeyShare(device *crypto.DeviceIdentity, info event.RequestedKeyInfo) *crypto.KeyShareRejection {
cfg := helper.bridge.Config.Bridge.Encryption.KeySharing
if !cfg.Allow {
return &crypto.KeyShareRejectNoResponse
} else if device.Trust == crypto.TrustStateBlacklisted {
return &crypto.KeyShareRejectBlacklisted
} else if device.Trust == crypto.TrustStateVerified || !cfg.RequireVerification {
portal := helper.bridge.GetPortalByMXID(info.RoomID)
if portal == nil {
helper.log.Debugfln("Rejecting key request for %s from %s/%s: room is not a portal", info.SessionID, device.UserID, device.DeviceID)
return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnavailable, Reason: "Requested room is not a portal room"}
}
user := helper.bridge.GetUserByMXID(device.UserID)
// FIXME reimplement IsInPortal
if !user.Admin /*&& !user.IsInPortal(portal.Key)*/ {
helper.log.Debugfln("Rejecting key request for %s from %s/%s: user is not in portal", info.SessionID, device.UserID, device.DeviceID)
return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "You're not in that portal"}
}
helper.log.Debugfln("Accepting key request for %s from %s/%s", info.SessionID, device.UserID, device.DeviceID)
return nil
} else {
return &crypto.KeyShareRejectUnverified
}
}
func (helper *CryptoHelper) loginBot() (*mautrix.Client, error) {
deviceID := helper.store.FindDeviceID()
if len(deviceID) > 0 {
helper.log.Debugln("Found existing device ID for bot in database:", deviceID)
}
client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, "", "")
if err != nil {
return nil, fmt.Errorf("failed to initialize client: %w", err)
}
client.Logger = helper.baseLog.Sub("Bot")
client.Client = helper.bridge.AS.HTTPClient
client.DefaultHTTPRetries = helper.bridge.AS.DefaultHTTPRetries
flows, err := client.GetLoginFlows()
if err != nil {
return nil, fmt.Errorf("failed to get supported login flows: %w", err)
} else if !flows.HasFlow(mautrix.AuthTypeAppservice) {
return nil, fmt.Errorf("homeserver does not support appservice login")
}
// We set the API token to the AS token here to authenticate the appservice login
// It'll get overridden after the login
client.AccessToken = helper.bridge.AS.Registration.AppToken
resp, err := client.Login(&mautrix.ReqLogin{
Type: mautrix.AuthTypeAppservice,
Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(helper.bridge.AS.BotMXID())},
DeviceID: deviceID,
InitialDeviceDisplayName: "WhatsApp Bridge",
StoreCredentials: true,
})
if err != nil {
return nil, fmt.Errorf("failed to log in as bridge bot: %w", err)
}
helper.store.DeviceID = resp.DeviceID
return client, nil
}
func (helper *CryptoHelper) Start() {
helper.log.Debugln("Starting syncer for receiving to-device messages")
err := helper.client.Sync()
if err != nil {
helper.log.Errorln("Fatal error syncing:", err)
} else {
helper.log.Infoln("Bridge bot to-device syncer stopped without error")
}
}
func (helper *CryptoHelper) Stop() {
helper.log.Debugln("CryptoHelper.Stop() called, stopping bridge bot sync")
helper.client.StopSync()
}
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
return helper.mach.DecryptMegolmEvent(evt)
}
func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content event.Content) (*event.EncryptedEventContent, error) {
encrypted, err := helper.mach.EncryptMegolmEvent(roomID, evtType, &content)
if err != nil {
if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
return nil, err
}
helper.log.Debugfln("Got %v while encrypting event for %s, sharing group session and trying again...", err, roomID)
users, err := helper.store.GetRoomMembers(roomID)
if err != nil {
return nil, fmt.Errorf("failed to get room member list: %w", err)
}
err = helper.mach.ShareGroupSession(roomID, users)
if err != nil {
return nil, fmt.Errorf("failed to share group session: %w", err)
}
encrypted, err = helper.mach.EncryptMegolmEvent(roomID, evtType, &content)
if err != nil {
return nil, fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err)
}
}
return encrypted, nil
}
func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
}
func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}})
if err != nil {
helper.log.Warnfln("Failed to send key request to %s/%s for %s in %s: %v", userID, deviceID, sessionID, roomID, err)
} else {
helper.log.Debugfln("Sent key request to %s/%s for %s in %s", userID, deviceID, sessionID, roomID)
}
}
func (helper *CryptoHelper) ResetSession(roomID id.RoomID) {
err := helper.mach.CryptoStore.RemoveOutboundGroupSession(roomID)
if err != nil {
helper.log.Debugfln("Error manually removing outbound group session in %s: %v", roomID, err)
}
}
func (helper *CryptoHelper) HandleMemberEvent(evt *event.Event) {
helper.mach.HandleMemberEvent(evt)
}
type cryptoSyncer struct {
*crypto.OlmMachine
}
func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string) error {
done := make(chan struct{})
go func() {
defer func() {
if err := recover(); err != nil {
syncer.Log.Error("Processing sync response (%s) panicked: %v\n%s", since, err, debug.Stack())
}
done <- struct{}{}
}()
syncer.Log.Trace("Starting sync response handling (%s)", since)
syncer.ProcessSyncResponse(resp, since)
syncer.Log.Trace("Successfully handled sync response (%s)", since)
}()
select {
case <-done:
case <-time.After(30 * time.Second):
syncer.Log.Warn("Handling sync response (%s) is taking unusually long", since)
}
return nil
}
func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) {
syncer.Log.Error("Error /syncing, waiting 10 seconds: %v", err)
return 10 * time.Second, nil
}
func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter {
everything := []event.Type{{Type: "*"}}
return &mautrix.Filter{
Presence: mautrix.FilterPart{NotTypes: everything},
AccountData: mautrix.FilterPart{NotTypes: everything},
Room: mautrix.RoomFilter{
IncludeLeave: false,
Ephemeral: mautrix.FilterPart{NotTypes: everything},
AccountData: mautrix.FilterPart{NotTypes: everything},
State: mautrix.FilterPart{NotTypes: everything},
Timeline: mautrix.FilterPart{NotTypes: everything},
},
}
}
type cryptoLogger struct {
int maulogger.Logger
}
func (c *cryptoLogger) Error(message string, args ...interface{}) {
c.int.Errorfln(message, args...)
}
func (c *cryptoLogger) Warn(message string, args ...interface{}) {
c.int.Warnfln(message, args...)
}
func (c *cryptoLogger) Debug(message string, args ...interface{}) {
c.int.Debugfln(message, args...)
}
func (c *cryptoLogger) Trace(message string, args ...interface{}) {
c.int.Logfln(levelTrace, message, args...)
}
type cryptoClientStore struct {
int *database.SQLCryptoStore
}
func (c cryptoClientStore) SaveFilterID(_ id.UserID, _ string) {}
func (c cryptoClientStore) LoadFilterID(_ id.UserID) string { return "" }
func (c cryptoClientStore) SaveRoom(_ *mautrix.Room) {}
func (c cryptoClientStore) LoadRoom(_ id.RoomID) *mautrix.Room { return nil }
func (c cryptoClientStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
c.int.PutNextBatch(nextBatchToken)
}
func (c cryptoClientStore) LoadNextBatch(_ id.UserID) string {
return c.int.GetNextBatch()
}
var _ mautrix.Storer = (*cryptoClientStore)(nil)
type cryptoStateStore struct {
bridge *Bridge
}
var _ crypto.StateStore = (*cryptoStateStore)(nil)
func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool {
portal := c.bridge.GetPortalByMXID(id)
if portal != nil {
return portal.Encrypted
}
return false
}
func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID {
return c.bridge.StateStore.FindSharedRooms(id)
}
func (c *cryptoStateStore) GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent {
// TODO implement
return nil
}

View file

@ -75,8 +75,8 @@ func (puppet *Puppet) loginWithSharedSecret(mxid id.UserID) (string, error) {
Type: mautrix.AuthTypePassword, Type: mautrix.AuthTypePassword,
Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(mxid)}, Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(mxid)},
Password: hex.EncodeToString(mac.Sum(nil)), Password: hex.EncodeToString(mac.Sum(nil)),
DeviceID: "WhatsApp Bridge", DeviceID: "WhatsApp bridge",
InitialDeviceDisplayName: "WhatsApp Bridge", InitialDeviceDisplayName: "WhatsApp bridge",
}) })
if err != nil { if err != nil {
return "", err return "", err
@ -84,22 +84,22 @@ func (puppet *Puppet) loginWithSharedSecret(mxid id.UserID) (string, error) {
return resp.AccessToken, nil return resp.AccessToken, nil
} }
func (bridge *Bridge) newDoublePuppetClient(mxid id.UserID, accessToken string) (*mautrix.Client, error) { func (br *WABridge) newDoublePuppetClient(mxid id.UserID, accessToken string) (*mautrix.Client, error) {
_, homeserver, err := mxid.Parse() _, homeserver, err := mxid.Parse()
if err != nil { if err != nil {
return nil, err return nil, err
} }
homeserverURL, found := bridge.Config.Bridge.DoublePuppetServerMap[homeserver] homeserverURL, found := br.Config.Bridge.DoublePuppetServerMap[homeserver]
if !found { if !found {
if homeserver == bridge.AS.HomeserverDomain { if homeserver == br.AS.HomeserverDomain {
homeserverURL = bridge.AS.HomeserverURL homeserverURL = br.AS.HomeserverURL
} else if bridge.Config.Bridge.DoublePuppetAllowDiscovery { } else if br.Config.Bridge.DoublePuppetAllowDiscovery {
resp, err := mautrix.DiscoverClientAPI(homeserver) resp, err := mautrix.DiscoverClientAPI(homeserver)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err) return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err)
} }
homeserverURL = resp.Homeserver.BaseURL homeserverURL = resp.Homeserver.BaseURL
bridge.Log.Debugfln("Discovered URL %s for %s to enable double puppeting for %s", homeserverURL, homeserver, mxid) br.Log.Debugfln("Discovered URL %s for %s to enable double puppeting for %s", homeserverURL, homeserver, mxid)
} else { } else {
return nil, fmt.Errorf("double puppeting from %s is not allowed", homeserver) return nil, fmt.Errorf("double puppeting from %s is not allowed", homeserver)
} }
@ -108,9 +108,9 @@ func (bridge *Bridge) newDoublePuppetClient(mxid id.UserID, accessToken string)
if err != nil { if err != nil {
return nil, err return nil, err
} }
client.Logger = bridge.AS.Log.Sub(mxid.String()) client.Logger = br.AS.Log.Sub(mxid.String())
client.Client = bridge.AS.HTTPClient client.Client = br.AS.HTTPClient
client.DefaultHTTPRetries = bridge.AS.DefaultHTTPRetries client.DefaultHTTPRetries = br.AS.DefaultHTTPRetries
return client, nil return client, nil
} }

View file

@ -26,7 +26,9 @@ import (
"time" "time"
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
) )
type BackfillType int type BackfillType int
@ -165,7 +167,7 @@ func (b *Backfill) String() string {
) )
} }
func (b *Backfill) Scan(row Scannable) *Backfill { func (b *Backfill) Scan(row dbutil.Scannable) *Backfill {
err := row.Scan(&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, &b.MaxBatchEvents, &b.MaxTotalEvents, &b.BatchDelay) err := row.Scan(&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, &b.MaxBatchEvents, &b.MaxTotalEvents, &b.BatchDelay)
if err != nil { if err != nil {
if !errors.Is(err, sql.ErrNoRows) { if !errors.Is(err, sql.ErrNoRows) {
@ -256,7 +258,7 @@ type BackfillState struct {
FirstExpectedTimestamp uint64 FirstExpectedTimestamp uint64
} }
func (b *BackfillState) Scan(row Scannable) *BackfillState { func (b *BackfillState) Scan(row dbutil.Scannable) *BackfillState {
err := row.Scan(&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp) err := row.Scan(&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp)
if err != nil { if err != nil {
if !errors.Is(err, sql.ErrNoRows) { if !errors.Is(err, sql.ErrNoRows) {

View file

@ -1,18 +1,8 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // Copyright (c) 2022 Tulir Asokan
// Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This Source Code Form is subject to the terms of the Mozilla Public
// it under the terms of the GNU Affero General Public License as published by // License, v. 2.0. If a copy of the MPL was not distributed with this
// the Free Software Foundation, either version 3 of the License, or // file, You can obtain one at http://mozilla.org/MPL/2.0/.
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//go:build cgo && !nocrypto //go:build cgo && !nocrypto
@ -21,8 +11,6 @@ package database
import ( import (
"database/sql" "database/sql"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
@ -37,11 +25,9 @@ var _ crypto.Store = (*SQLCryptoStore)(nil)
func NewSQLCryptoStore(db *Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore { func NewSQLCryptoStore(db *Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore {
return &SQLCryptoStore{ return &SQLCryptoStore{
SQLCryptoStore: crypto.NewSQLCryptoStore(db.DB, db.dialect, "", "", SQLCryptoStore: crypto.NewSQLCryptoStore(db.Database, "", "", []byte("maunium.net/go/mautrix-whatsapp")),
[]byte("maunium.net/go/mautrix-whatsapp"), UserID: userID,
&cryptoLogger{db.log.Sub("CryptoStore")}), GhostIDFormat: ghostIDFormat,
UserID: userID,
GhostIDFormat: ghostIDFormat,
} }
} }
@ -76,30 +62,3 @@ func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.User
} }
return return
} }
// TODO merge this with the one in the parent package
type cryptoLogger struct {
int log.Logger
}
var levelTrace = log.Level{
Name: "TRACE",
Severity: -10,
Color: -1,
}
func (c *cryptoLogger) Error(message string, args ...interface{}) {
c.int.Errorfln(message, args...)
}
func (c *cryptoLogger) Warn(message string, args ...interface{}) {
c.int.Warnfln(message, args...)
}
func (c *cryptoLogger) Debug(message string, args ...interface{}) {
c.int.Debugfln(message, args...)
}
func (c *cryptoLogger) Trace(message string, args ...interface{}) {
c.int.Logfln(levelTrace, message, args...)
}

View file

@ -17,21 +17,17 @@
package database package database
import ( import (
"database/sql"
"errors" "errors"
"fmt"
"net" "net"
"time" "time"
"github.com/lib/pq" "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
log "maunium.net/go/maulogger/v2"
"go.mau.fi/whatsmeow/store" "go.mau.fi/whatsmeow/store"
"go.mau.fi/whatsmeow/store/sqlstore" "go.mau.fi/whatsmeow/store/sqlstore"
"maunium.net/go/mautrix-whatsapp/config"
"maunium.net/go/mautrix-whatsapp/database/upgrades" "maunium.net/go/mautrix-whatsapp/database/upgrades"
"maunium.net/go/mautrix/util/dbutil"
) )
func init() { func init() {
@ -39,9 +35,7 @@ func init() {
} }
type Database struct { type Database struct {
*sql.DB *dbutil.Database
log log.Logger
dialect string
User *UserQuery User *UserQuery
Portal *PortalQuery Portal *PortalQuery
@ -55,79 +49,46 @@ type Database struct {
MediaBackfillRequest *MediaBackfillRequestQuery MediaBackfillRequest *MediaBackfillRequestQuery
} }
func New(cfg config.DatabaseConfig, baseLog log.Logger) (*Database, error) { func New(baseDB *dbutil.Database) *Database {
conn, err := sql.Open(cfg.Type, cfg.URI) db := &Database{Database: baseDB}
if err != nil { db.UpgradeTable = upgrades.Table
return nil, err
}
db := &Database{
DB: conn,
log: baseLog.Sub("Database"),
dialect: cfg.Type,
}
db.User = &UserQuery{ db.User = &UserQuery{
db: db, db: db,
log: db.log.Sub("User"), log: db.Log.Sub("User"),
} }
db.Portal = &PortalQuery{ db.Portal = &PortalQuery{
db: db, db: db,
log: db.log.Sub("Portal"), log: db.Log.Sub("Portal"),
} }
db.Puppet = &PuppetQuery{ db.Puppet = &PuppetQuery{
db: db, db: db,
log: db.log.Sub("Puppet"), log: db.Log.Sub("Puppet"),
} }
db.Message = &MessageQuery{ db.Message = &MessageQuery{
db: db, db: db,
log: db.log.Sub("Message"), log: db.Log.Sub("Message"),
} }
db.Reaction = &ReactionQuery{ db.Reaction = &ReactionQuery{
db: db, db: db,
log: db.log.Sub("Reaction"), log: db.Log.Sub("Reaction"),
} }
db.DisappearingMessage = &DisappearingMessageQuery{ db.DisappearingMessage = &DisappearingMessageQuery{
db: db, db: db,
log: db.log.Sub("DisappearingMessage"), log: db.Log.Sub("DisappearingMessage"),
} }
db.Backfill = &BackfillQuery{ db.Backfill = &BackfillQuery{
db: db, db: db,
log: db.log.Sub("Backfill"), log: db.Log.Sub("Backfill"),
} }
db.HistorySync = &HistorySyncQuery{ db.HistorySync = &HistorySyncQuery{
db: db, db: db,
log: db.log.Sub("HistorySync"), log: db.Log.Sub("HistorySync"),
} }
db.MediaBackfillRequest = &MediaBackfillRequestQuery{ db.MediaBackfillRequest = &MediaBackfillRequestQuery{
db: db, db: db,
log: db.log.Sub("MediaBackfillRequest"), log: db.Log.Sub("MediaBackfillRequest"),
} }
return db
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetMaxIdleConns(cfg.MaxIdleConns)
if len(cfg.ConnMaxIdleTime) > 0 {
maxIdleTimeDuration, err := time.ParseDuration(cfg.ConnMaxIdleTime)
if err != nil {
return nil, fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
}
db.SetConnMaxIdleTime(maxIdleTimeDuration)
}
if len(cfg.ConnMaxLifetime) > 0 {
maxLifetimeDuration, err := time.ParseDuration(cfg.ConnMaxLifetime)
if err != nil {
return nil, fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
}
db.SetConnMaxLifetime(maxLifetimeDuration)
}
return db, nil
}
func (db *Database) Init() error {
return upgrades.Run(db.log.Sub("Upgrade"), db.dialect, db.DB)
}
type Scannable interface {
Scan(...interface{}) error
} }
func isRetryableError(err error) bool { func isRetryableError(err error) bool {
@ -145,7 +106,7 @@ func isRetryableError(err error) bool {
} }
func (db *Database) HandleSignalStoreError(device *store.Device, action string, attemptIndex int, err error) (retry bool) { func (db *Database) HandleSignalStoreError(device *store.Device, action string, attemptIndex int, err error) (retry bool) {
if db.dialect != "sqlite" && isRetryableError(err) { if db.Dialect != dbutil.SQLite && isRetryableError(err) {
sleepTime := time.Duration(attemptIndex*2) * time.Second sleepTime := time.Duration(attemptIndex*2) * time.Second
device.Log.Warnf("Failed to %s (attempt #%d): %v - retrying in %v", action, attemptIndex+1, err, sleepTime) device.Log.Warnf("Failed to %s (attempt #%d): %v - retrying in %v", action, attemptIndex+1, err, sleepTime)
time.Sleep(sleepTime) time.Sleep(sleepTime)

View file

@ -24,6 +24,7 @@ import (
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
) )
type DisappearingMessageQuery struct { type DisappearingMessageQuery struct {
@ -94,7 +95,7 @@ type DisappearingMessage struct {
ExpireAt time.Time ExpireAt time.Time
} }
func (msg *DisappearingMessage) Scan(row Scannable) *DisappearingMessage { func (msg *DisappearingMessage) Scan(row dbutil.Scannable) *DisappearingMessage {
var expireIn int64 var expireIn int64
var expireAt sql.NullInt64 var expireAt sql.NullInt64
err := row.Scan(&msg.RoomID, &msg.EventID, &expireIn, &expireAt) err := row.Scan(&msg.RoomID, &msg.EventID, &expireIn, &expireAt)

View file

@ -27,7 +27,9 @@ import (
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
) )
type HistorySyncQuery struct { type HistorySyncQuery struct {
@ -139,7 +141,7 @@ func (hsc *HistorySyncConversation) Upsert() {
} }
} }
func (hsc *HistorySyncConversation) Scan(row Scannable) *HistorySyncConversation { func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) *HistorySyncConversation {
err := row.Scan( err := row.Scan(
&hsc.UserID, &hsc.UserID,
&hsc.ConversationID, &hsc.ConversationID,
@ -166,7 +168,7 @@ func (hsc *HistorySyncConversation) Scan(row Scannable) *HistorySyncConversation
func (hsq *HistorySyncQuery) GetNMostRecentConversations(userID id.UserID, n int) (conversations []*HistorySyncConversation) { func (hsq *HistorySyncQuery) GetNMostRecentConversations(userID id.UserID, n int) (conversations []*HistorySyncConversation) {
nPtr := &n nPtr := &n
// Negative limit on SQLite means unlimited, but Postgres prefers a NULL limit. // Negative limit on SQLite means unlimited, but Postgres prefers a NULL limit.
if n < 0 && hsq.db.dialect == "postgres" { if n < 0 && hsq.db.Dialect == dbutil.Postgres {
nPtr = nil nPtr = nil
} }
rows, err := hsq.db.Query(getNMostRecentConversations, userID, nPtr) rows, err := hsq.db.Query(getNMostRecentConversations, userID, nPtr)

View file

@ -22,7 +22,9 @@ import (
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
) )
type MediaBackfillRequestStatus int type MediaBackfillRequestStatus int
@ -100,7 +102,7 @@ func (mbr *MediaBackfillRequest) Upsert() {
} }
} }
func (mbr *MediaBackfillRequest) Scan(row Scannable) *MediaBackfillRequest { func (mbr *MediaBackfillRequest) Scan(row dbutil.Scannable) *MediaBackfillRequest {
err := row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error) err := row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error)
if err != nil { if err != nil {
if !errors.Is(err, sql.ErrNoRows) { if !errors.Is(err, sql.ErrNoRows) {

View file

@ -25,6 +25,7 @@ import (
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
) )
@ -163,7 +164,7 @@ func (msg *Message) IsFakeJID() bool {
return strings.HasPrefix(msg.JID, "FAKE::") || msg.JID == string(msg.MXID) return strings.HasPrefix(msg.JID, "FAKE::") || msg.JID == string(msg.MXID)
} }
func (msg *Message) Scan(row Scannable) *Message { func (msg *Message) Scan(row dbutil.Scannable) *Message {
var ts int64 var ts int64
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID) err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID)
if err != nil { if err != nil {

View file

@ -22,6 +22,7 @@ import (
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
) )
@ -152,7 +153,7 @@ type Portal struct {
ExpirationTime uint32 ExpirationTime uint32
} }
func (portal *Portal) Scan(row Scannable) *Portal { func (portal *Portal) Scan(row dbutil.Scannable) *Portal {
var mxid, avatarURL, firstEventID, nextBatchID, relayUserID sql.NullString var mxid, avatarURL, firstEventID, nextBatchID, relayUserID sql.NullString
err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar, &avatarURL, &portal.Encrypted, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime) err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar, &avatarURL, &portal.Encrypted, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime)
if err != nil { if err != nil {

View file

@ -20,7 +20,9 @@ import (
"database/sql" "database/sql"
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
) )
@ -97,7 +99,7 @@ type Puppet struct {
EnableReceipts bool EnableReceipts bool
} }
func (puppet *Puppet) Scan(row Scannable) *Puppet { func (puppet *Puppet) Scan(row dbutil.Scannable) *Puppet {
var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString
var quality sql.NullInt64 var quality sql.NullInt64
var enablePresence, enableReceipts sql.NullBool var enablePresence, enableReceipts sql.NullBool

View file

@ -23,6 +23,7 @@ import (
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
) )
@ -85,7 +86,7 @@ type Reaction struct {
JID types.MessageID JID types.MessageID
} }
func (reaction *Reaction) Scan(row Scannable) *Reaction { func (reaction *Reaction) Scan(row dbutil.Scannable) *Reaction {
err := row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID) err := row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID)
if err != nil { if err != nil {
if !errors.Is(err, sql.ErrNoRows) { if !errors.Is(err, sql.ErrNoRows) {

View file

@ -1,282 +0,0 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package database
import (
"database/sql"
"encoding/json"
"errors"
"sync"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
type SQLStateStore struct {
*appservice.TypingStateStore
db *Database
log log.Logger
Typing map[id.RoomID]map[id.UserID]int64
typingLock sync.RWMutex
}
var _ appservice.StateStore = (*SQLStateStore)(nil)
func NewSQLStateStore(db *Database) *SQLStateStore {
return &SQLStateStore{
TypingStateStore: appservice.NewTypingStateStore(),
db: db,
log: db.log.Sub("StateStore"),
}
}
func (store *SQLStateStore) IsRegistered(userID id.UserID) bool {
var isRegistered bool
err := store.db.
QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
Scan(&isRegistered)
if err != nil {
store.log.Warnfln("Failed to scan registration existence for %s: %v", userID, err)
}
return isRegistered
}
func (store *SQLStateStore) MarkRegistered(userID id.UserID) {
_, err := store.db.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
if err != nil {
store.log.Warnfln("Failed to mark %s as registered: %v", userID, err)
}
}
func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
members := make(map[id.UserID]*event.MemberEventContent)
rows, err := store.db.Query("SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID)
if err != nil {
return members
}
var userID id.UserID
var member event.MemberEventContent
for rows.Next() {
err = rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil {
store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
} else {
members[userID] = &member
}
}
return members
}
func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
membership := event.MembershipLeave
err := store.db.
QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
Scan(&membership)
if err != nil && err != sql.ErrNoRows {
store.log.Warnfln("Failed to scan membership of %s in %s: %v", userID, roomID, err)
}
return membership
}
func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
member, ok := store.TryGetMember(roomID, userID)
if !ok {
member.Membership = event.MembershipLeave
}
return member
}
func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) {
var member event.MemberEventContent
err := store.db.
QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil && err != sql.ErrNoRows {
store.log.Warnfln("Failed to scan member info of %s in %s: %v", userID, roomID, err)
}
return &member, err == nil
}
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
rows, err := store.db.Query(`
SELECT room_id FROM mx_user_profile
LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
WHERE user_id=$1 AND portal.encrypted=true
`, userID)
if err != nil {
store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err)
return
}
for rows.Next() {
var roomID id.RoomID
err = rows.Scan(&roomID)
if err != nil {
store.log.Warnfln("Failed to scan room ID: %v", err)
} else {
rooms = append(rooms, roomID)
}
}
return
}
func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join")
}
func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join", "invite")
}
func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
membership := store.GetMembership(roomID, userID)
for _, allowedMembership := range allowedMemberships {
if allowedMembership == membership {
return true
}
}
return false
}
func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
_, err := store.db.Exec(`
INSERT INTO mx_user_profile (room_id, user_id, membership) VALUES ($1, $2, $3)
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
`, roomID, userID, membership)
if err != nil {
store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, membership, err)
}
}
func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
_, err := store.db.Exec(`
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url
`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
if err != nil {
store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err)
}
}
func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
levelsBytes, err := json.Marshal(levels)
if err != nil {
store.log.Errorfln("Failed to marshal power levels of %s: %v", roomID, err)
return
}
_, err = store.db.Exec(`
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
`, roomID, levelsBytes)
if err != nil {
store.log.Warnfln("Failed to store power levels of %s: %v", roomID, err)
}
}
func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
var data []byte
err := store.db.
QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
Scan(&data)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.log.Errorfln("Failed to scan power levels of %s: %v", roomID, err)
}
return
}
levels = &event.PowerLevelsEventContent{}
err = json.Unmarshal(data, levels)
if err != nil {
store.log.Errorfln("Failed to parse power levels of %s: %v", roomID, err)
return nil
}
return
}
func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
if store.db.dialect == "postgres" {
var powerLevel int
err := store.db.
QueryRow(`
SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
FROM mx_room_state WHERE room_id=$1
`, roomID, userID).
Scan(&powerLevel)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
store.log.Errorfln("Failed to scan power level of %s in %s: %v", userID, roomID, err)
}
return powerLevel
}
return store.GetPowerLevels(roomID).GetUserLevel(userID)
}
func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
if store.db.dialect == "postgres" {
defaultType := "events_default"
defaultValue := 0
if eventType.IsState() {
defaultType = "state_default"
defaultValue = 50
}
var powerLevel int
err := store.db.
QueryRow(`
SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
FROM mx_room_state WHERE room_id=$1
`, roomID, eventType.Type, defaultType, defaultValue).
Scan(&powerLevel)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.log.Errorfln("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
}
return defaultValue
}
return powerLevel
}
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
}
func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
if store.db.dialect == "postgres" {
defaultType := "events_default"
defaultValue := 0
if eventType.IsState() {
defaultType = "state_default"
defaultValue = 50
}
var hasPower bool
err := store.db.
QueryRow(`SELECT
COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
>=
COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
Scan(&hasPower)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.log.Errorfln("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
}
return defaultValue == 0
}
return hasPower
}
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
}

View file

@ -0,0 +1,181 @@
-- v0 -> v48: Latest revision
CREATE TABLE "user" (
mxid TEXT PRIMARY KEY,
username TEXT UNIQUE,
agent SMALLINT,
device SMALLINT,
management_room TEXT,
space_room TEXT,
phone_last_seen BIGINT,
phone_last_pinged BIGINT,
timezone TEXT
);
CREATE TABLE portal (
jid TEXT,
receiver TEXT,
mxid TEXT UNIQUE,
name TEXT NOT NULL,
topic TEXT NOT NULL,
avatar TEXT NOT NULL,
avatar_url TEXT,
encrypted BOOLEAN NOT NULL DEFAULT false,
first_event_id TEXT,
next_batch_id TEXT,
relay_user_id TEXT,
expiration_time BIGINT NOT NULL DEFAULT 0,
PRIMARY KEY (jid, receiver)
);
CREATE TABLE puppet (
username TEXT PRIMARY KEY,
displayname TEXT,
name_quality SMALLINT,
avatar TEXT,
avatar_url TEXT,
custom_mxid TEXT,
access_token TEXT,
next_batch TEXT,
enable_presence BOOLEAN NOT NULL DEFAULT true,
enable_receipts BOOLEAN NOT NULL DEFAULT true
);
-- only: postgres
CREATE TYPE error_type AS ENUM ('', 'decryption_failed', 'media_not_found');
CREATE TABLE message (
chat_jid TEXT,
chat_receiver TEXT,
jid TEXT,
mxid TEXT UNIQUE,
sender TEXT,
timestamp BIGINT,
sent BOOLEAN,
error error_type,
type TEXT,
broadcast_list_jid TEXT,
PRIMARY KEY (chat_jid, chat_receiver, jid),
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
);
CREATE TABLE reaction (
chat_jid TEXT,
chat_receiver TEXT,
target_jid TEXT,
sender TEXT,
mxid TEXT NOT NULL,
jid TEXT NOT NULL,
PRIMARY KEY (chat_jid, chat_receiver, target_jid, sender),
FOREIGN KEY (chat_jid, chat_receiver, target_jid) REFERENCES message(chat_jid, chat_receiver, jid)
ON DELETE CASCADE ON UPDATE CASCADE
);
CREATE TABLE disappearing_message (
room_id TEXT,
event_id TEXT,
expire_in BIGINT NOT NULL,
expire_at BIGINT,
PRIMARY KEY (room_id, event_id)
);
CREATE TABLE user_portal (
user_mxid TEXT,
portal_jid TEXT,
portal_receiver TEXT,
last_read_ts BIGINT NOT NULL DEFAULT 0,
in_space BOOLEAN NOT NULL DEFAULT false,
PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON UPDATE CASCADE ON DELETE CASCADE
);
CREATE TABLE backfill_queue (
queue_id INTEGER PRIMARY KEY
-- only: postgres
GENERATED ALWAYS AS IDENTITY
,
user_mxid TEXT,
type INTEGER NOT NULL,
priority INTEGER NOT NULL,
portal_jid TEXT,
portal_receiver TEXT,
time_start TIMESTAMP,
dispatch_time TIMESTAMP,
completed_at TIMESTAMP,
batch_delay INTEGER,
max_batch_events INTEGER NOT NULL,
max_total_events INTEGER,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
);
CREATE TABLE backfill_state (
user_mxid TEXT,
portal_jid TEXT,
portal_receiver TEXT,
processing_batch BOOLEAN,
backfill_complete BOOLEAN,
first_expected_ts TIMESTAMP,
PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal (jid, receiver) ON DELETE CASCADE
);
CREATE TABLE media_backfill_requests (
user_mxid TEXT,
portal_jid TEXT,
portal_receiver TEXT,
event_id TEXT,
media_key bytea,
status INTEGER,
error TEXT,
PRIMARY KEY (user_mxid, portal_jid, portal_receiver, event_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON UPDATE CASCADE ON DELETE CASCADE
);
CREATE TABLE history_sync_conversation (
user_mxid TEXT,
conversation_id TEXT,
portal_jid TEXT,
portal_receiver TEXT,
last_message_timestamp TIMESTAMP,
archived BOOLEAN,
pinned INTEGER,
mute_end_time TIMESTAMP,
disappearing_mode INTEGER,
end_of_history_transfer_type INTEGER,
ephemeral_Expiration INTEGER,
marked_as_unread BOOLEAN,
unread_count INTEGER,
PRIMARY KEY (user_mxid, conversation_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON UPDATE CASCADE ON DELETE CASCADE
);
CREATE TABLE history_sync_message (
user_mxid TEXT,
conversation_id TEXT,
message_id TEXT,
timestamp TIMESTAMP,
data bytea,
inserted_time TIMESTAMP,
PRIMARY KEY (user_mxid, conversation_id, message_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE,
FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE
);

View file

@ -1,67 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[0] = upgrade{"Initial schema", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`CREATE TABLE IF NOT EXISTS portal (
jid VARCHAR(255),
receiver VARCHAR(255),
mxid VARCHAR(255) UNIQUE,
name VARCHAR(255) NOT NULL,
topic VARCHAR(255) NOT NULL,
avatar VARCHAR(255) NOT NULL,
PRIMARY KEY (jid, receiver)
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS puppet (
jid VARCHAR(255) PRIMARY KEY,
avatar VARCHAR(255),
displayname VARCHAR(255),
name_quality SMALLINT
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS "user" (
mxid VARCHAR(255) PRIMARY KEY,
jid VARCHAR(255) UNIQUE,
management_room VARCHAR(255),
client_id VARCHAR(255),
client_token VARCHAR(255),
server_token VARCHAR(255),
enc_key bytea,
mac_key bytea
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS message (
chat_jid VARCHAR(255),
chat_receiver VARCHAR(255),
jid VARCHAR(255),
mxid VARCHAR(255) NOT NULL UNIQUE,
sender VARCHAR(255) NOT NULL,
content bytea NOT NULL,
PRIMARY KEY (chat_jid, chat_receiver, jid),
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
)`)
if err != nil {
return err
}
return nil
}}
}

View file

@ -1,15 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[2] = upgrade{"Add timestamp column to messages", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec("ALTER TABLE message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0")
if err != nil {
return err
}
return nil
}}
}

View file

@ -1,15 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[3] = upgrade{"Add last_connection column to users", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN last_connection BIGINT NOT NULL DEFAULT 0`)
if err != nil {
return err
}
return nil
}}
}

View file

@ -1,23 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[5] = upgrade{"Add columns to store custom puppet info", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN custom_mxid VARCHAR(255)`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE puppet ADD COLUMN access_token VARCHAR(1023)`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE puppet ADD COLUMN next_batch VARCHAR(255)`)
if err != nil {
return err
}
return nil
}}
}

View file

@ -1,19 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[6] = upgrade{"Add user-portal mapping table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`CREATE TABLE user_portal (
user_jid VARCHAR(255),
portal_jid VARCHAR(255),
portal_receiver VARCHAR(255),
PRIMARY KEY (user_jid, portal_jid, portal_receiver),
FOREIGN KEY (user_jid) REFERENCES "user"(jid) ON DELETE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
)`)
return err
}}
}

View file

@ -1,19 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[7] = upgrade{"Add columns to store avatar MXC URIs", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN avatar_url VARCHAR(255)`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE portal ADD COLUMN avatar_url VARCHAR(255)`)
if err != nil {
return err
}
return nil
}}
}

View file

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[8] = upgrade{"Add columns to store portal in filtering community meta", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE user_portal ADD COLUMN in_community BOOLEAN NOT NULL DEFAULT FALSE`)
return err
}}
}

View file

@ -1,39 +0,0 @@
package upgrades
import (
"database/sql"
"strings"
)
func init() {
userProfileTable := `CREATE TABLE mx_user_profile (
room_id VARCHAR(255),
user_id VARCHAR(255),
membership VARCHAR(15) NOT NULL,
PRIMARY KEY (room_id, user_id)
)`
roomStateTable := `CREATE TABLE mx_room_state (
room_id VARCHAR(255) PRIMARY KEY,
power_levels TEXT
)`
registrationsTable := `CREATE TABLE mx_registrations (
user_id VARCHAR(255) PRIMARY KEY
)`
upgrades[9] = upgrade{"Move state store to main DB", func(tx *sql.Tx, ctx context) error {
if ctx.dialect == Postgres {
roomStateTable = strings.Replace(roomStateTable, "TEXT", "JSONB", 1)
}
if _, err := tx.Exec(userProfileTable); err != nil {
return err
} else if _, err = tx.Exec(roomStateTable); err != nil {
return err
} else if _, err = tx.Exec(registrationsTable); err != nil {
return err
}
return nil
}}
}

View file

@ -1,16 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[10] = upgrade{"Add columns to store full member info in state store", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE mx_user_profile ADD COLUMN displayname TEXT`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE mx_user_profile ADD COLUMN avatar_url VARCHAR(255)`)
return err
}}
}

View file

@ -1,16 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[11] = upgrade{"Adjust the length of column topic in portal", func(tx *sql.Tx, ctx context) error {
if ctx.dialect == SQLite {
// SQLite doesn't support constraint updates, but it isn't that careful with constraints anyway.
return nil
}
_, err := tx.Exec(`ALTER TABLE portal ALTER COLUMN topic TYPE VARCHAR(512)`)
return err
}}
}

View file

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[12] = upgrade{"Add encryption status to portal table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN encrypted BOOLEAN NOT NULL DEFAULT false`)
return err
}}
}

View file

@ -1,73 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[13] = upgrade{"Add crypto store to database", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`CREATE TABLE crypto_account (
device_id VARCHAR(255) PRIMARY KEY,
shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL,
account bytea NOT NULL
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_message_index (
sender_key CHAR(43),
session_id CHAR(43),
"index" INTEGER,
event_id VARCHAR(255) NOT NULL,
timestamp BIGINT NOT NULL,
PRIMARY KEY (sender_key, session_id, "index")
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_tracked_user (
user_id VARCHAR(255) PRIMARY KEY
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_device (
user_id VARCHAR(255),
device_id VARCHAR(255),
identity_key CHAR(43) NOT NULL,
signing_key CHAR(43) NOT NULL,
trust SMALLINT NOT NULL,
deleted BOOLEAN NOT NULL,
name VARCHAR(255) NOT NULL,
PRIMARY KEY (user_id, device_id)
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_olm_session (
session_id CHAR(43) PRIMARY KEY,
sender_key CHAR(43) NOT NULL,
session bytea NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_megolm_inbound_session (
session_id CHAR(43) PRIMARY KEY,
sender_key CHAR(43) NOT NULL,
signing_key CHAR(43) NOT NULL,
room_id VARCHAR(255) NOT NULL,
session bytea NOT NULL,
forwarding_chains bytea NOT NULL
)`)
if err != nil {
return err
}
return nil
}}
}

View file

@ -1,25 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[14] = upgrade{"Add outbound group sessions to database", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`CREATE TABLE crypto_megolm_outbound_session (
room_id VARCHAR(255) PRIMARY KEY,
session_id CHAR(43) NOT NULL UNIQUE,
session bytea NOT NULL,
shared BOOLEAN NOT NULL,
max_messages INTEGER NOT NULL,
message_count INTEGER NOT NULL,
max_age BIGINT NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL
)`)
if err != nil {
return err
}
return nil
}}
}

View file

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[15] = upgrade{"Add enable_presence column for puppets", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN enable_presence BOOLEAN NOT NULL DEFAULT true`)
return err
}}
}

View file

@ -1,13 +0,0 @@
package upgrades
import (
"database/sql"
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
)
func init() {
upgrades[16] = upgrade{"Add account_id to crypto store", func(tx *sql.Tx, c context) error {
return sql_store_upgrade.Upgrades[1](tx, c.dialect.String())
}}
}

View file

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[17] = upgrade{"Add enable_receipts column for puppets", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN enable_receipts BOOLEAN NOT NULL DEFAULT true`)
return err
}}
}

View file

@ -1,13 +0,0 @@
package upgrades
import (
"database/sql"
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
)
func init() {
upgrades[18] = upgrade{"Add megolm withheld data to crypto store", func(tx *sql.Tx, c context) error {
return sql_store_upgrade.Upgrades[2](tx, c.dialect.String())
}}
}

View file

@ -1,13 +0,0 @@
package upgrades
import (
"database/sql"
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
)
func init() {
upgrades[19] = upgrade{"Add cross-signing keys to crypto store", func(tx *sql.Tx, c context) error {
return sql_store_upgrade.Upgrades[3](tx, c.dialect.String())
}}
}

View file

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[20] = upgrade{"Add sent column for messages", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE message ADD COLUMN sent BOOLEAN NOT NULL DEFAULT true`)
return err
}}
}

View file

@ -1,44 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[21] = upgrade{"Remove message content from local database", func(tx *sql.Tx, ctx context) error {
if ctx.dialect == SQLite {
_, err := tx.Exec("ALTER TABLE message RENAME TO old_message")
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS message (
chat_jid TEXT,
chat_receiver TEXT,
jid TEXT,
mxid TEXT NOT NULL UNIQUE,
sender TEXT NOT NULL,
timestamp BIGINT NOT NULL,
sent BOOLEAN NOT NULL,
PRIMARY KEY (chat_jid, chat_receiver, jid),
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
)`)
if err != nil {
return err
}
_, err = tx.Exec("INSERT INTO message SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent FROM old_message")
return err
} else {
_, err := tx.Exec(`ALTER TABLE message DROP COLUMN content`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE message ALTER COLUMN timestamp DROP DEFAULT`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE message ALTER COLUMN sent DROP DEFAULT`)
return err
}
}}
}

View file

@ -1,13 +0,0 @@
package upgrades
import (
"database/sql"
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
)
func init() {
upgrades[23] = upgrade{"Replace VARCHAR(255) with TEXT in the crypto database", func(tx *sql.Tx, ctx context) error {
return sql_store_upgrade.Upgrades[4](tx, ctx.dialect.String())
}}
}

View file

@ -1,48 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[22] = upgrade{"Replace VARCHAR(255) with TEXT in the database", func(tx *sql.Tx, ctx context) error {
if ctx.dialect == SQLite {
// SQLite doesn't enforce varchar sizes anyway
return nil
}
return execMany(tx,
`ALTER TABLE message ALTER COLUMN chat_jid TYPE TEXT`,
`ALTER TABLE message ALTER COLUMN chat_receiver TYPE TEXT`,
`ALTER TABLE message ALTER COLUMN jid TYPE TEXT`,
`ALTER TABLE message ALTER COLUMN mxid TYPE TEXT`,
`ALTER TABLE message ALTER COLUMN sender TYPE TEXT`,
`ALTER TABLE portal ALTER COLUMN jid TYPE TEXT`,
`ALTER TABLE portal ALTER COLUMN receiver TYPE TEXT`,
`ALTER TABLE portal ALTER COLUMN mxid TYPE TEXT`,
`ALTER TABLE portal ALTER COLUMN name TYPE TEXT`,
`ALTER TABLE portal ALTER COLUMN topic TYPE TEXT`,
`ALTER TABLE portal ALTER COLUMN avatar TYPE TEXT`,
`ALTER TABLE portal ALTER COLUMN avatar_url TYPE TEXT`,
`ALTER TABLE puppet ALTER COLUMN jid TYPE TEXT`,
`ALTER TABLE puppet ALTER COLUMN avatar TYPE TEXT`,
`ALTER TABLE puppet ALTER COLUMN displayname TYPE TEXT`,
`ALTER TABLE puppet ALTER COLUMN custom_mxid TYPE TEXT`,
`ALTER TABLE puppet ALTER COLUMN access_token TYPE TEXT`,
`ALTER TABLE puppet ALTER COLUMN next_batch TYPE TEXT`,
`ALTER TABLE puppet ALTER COLUMN avatar_url TYPE TEXT`,
`ALTER TABLE "user" ALTER COLUMN mxid TYPE TEXT`,
`ALTER TABLE "user" ALTER COLUMN jid TYPE TEXT`,
`ALTER TABLE "user" ALTER COLUMN management_room TYPE TEXT`,
`ALTER TABLE "user" ALTER COLUMN client_id TYPE TEXT`,
`ALTER TABLE "user" ALTER COLUMN client_token TYPE TEXT`,
`ALTER TABLE "user" ALTER COLUMN server_token TYPE TEXT`,
`ALTER TABLE user_portal ALTER COLUMN user_jid TYPE TEXT`,
`ALTER TABLE user_portal ALTER COLUMN portal_jid TYPE TEXT`,
`ALTER TABLE user_portal ALTER COLUMN portal_receiver TYPE TEXT`,
)
}}
}

View file

@ -1,13 +0,0 @@
package upgrades
import (
"database/sql"
"go.mau.fi/whatsmeow/store/sqlstore"
)
func init() {
upgrades[24] = upgrade{"Add whatsmeow state store", func(tx *sql.Tx, ctx context) error {
return sqlstore.Upgrades[0](tx, sqlstore.NewWithDB(ctx.db, ctx.dialect.String(), nil))
}}
}

View file

@ -1,93 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[25] = upgrade{"Update things for multidevice", func(tx *sql.Tx, ctx context) error {
// This is probably not necessary
_, err := tx.Exec("DROP TABLE user_portal")
if err != nil {
return err
}
// Remove invalid puppet rows
_, err = tx.Exec("DELETE FROM puppet WHERE jid LIKE '%@g.us' OR jid LIKE '%@broadcast'")
if err != nil {
return err
}
// Remove the suffix from puppets since they'll all have the same suffix
_, err = tx.Exec("UPDATE puppet SET jid=REPLACE(jid, '@s.whatsapp.net', '')")
if err != nil {
return err
}
// Rename column to correctly represent the new content
_, err = tx.Exec("ALTER TABLE puppet RENAME COLUMN jid TO username")
if err != nil {
return err
}
if ctx.dialect == SQLite {
// Message content was removed from the main message table earlier, but the backup table still exists for SQLite
_, err = tx.Exec("DROP TABLE IF EXISTS old_message")
_, err = tx.Exec(`ALTER TABLE "user" RENAME TO old_user`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE "user" (
mxid TEXT PRIMARY KEY,
username TEXT UNIQUE,
agent SMALLINT,
device SMALLINT,
management_room TEXT
)`)
if err != nil {
return err
}
// No need to copy auth data, users need to relogin anyway
_, err = tx.Exec(`INSERT INTO "user" (mxid, management_room) SELECT mxid, management_room FROM old_user`)
if err != nil {
return err
}
_, err = tx.Exec("DROP TABLE old_user")
if err != nil {
return err
}
} else {
// The jid column never actually contained the full JID, so let's rename it.
_, err = tx.Exec(`ALTER TABLE "user" RENAME COLUMN jid TO username`)
if err != nil {
return err
}
// The auth data is now in the whatsmeow_device table.
for _, column := range []string{"last_connection", "client_id", "client_token", "server_token", "enc_key", "mac_key"} {
_, err = tx.Exec(`ALTER TABLE "user" DROP COLUMN ` + column)
if err != nil {
return err
}
}
// The whatsmeow_device table is keyed by the full JID, so we need to store the other parts of the JID here too.
_, err = tx.Exec(`ALTER TABLE "user" ADD COLUMN agent SMALLINT`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE "user" ADD COLUMN device SMALLINT`)
if err != nil {
return err
}
// Clear all usernames, the users need to relogin anyway.
_, err = tx.Exec(`UPDATE "user" SET username=null`)
if err != nil {
return err
}
}
return nil
}}
}

View file

@ -1,19 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[26] = upgrade{"Add columns to store infinite backfill pointers for portals", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN first_event_id TEXT NOT NULL DEFAULT ''`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE portal ADD COLUMN next_batch_id TEXT NOT NULL DEFAULT ''`)
if err != nil {
return err
}
return nil
}}
}

View file

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[27] = upgrade{"Add marker for WhatsApp decryption errors in message table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE message ADD COLUMN decryption_error BOOLEAN NOT NULL DEFAULT false`)
return err
}}
}

View file

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[28] = upgrade{"Add relay user field to portal table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN relay_user_id TEXT`)
return err
}}
}

View file

@ -1,22 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[29] = upgrade{"Replace VARCHAR(255) with TEXT in the Matrix state store", func(tx *sql.Tx, ctx context) error {
if ctx.dialect == SQLite {
// SQLite doesn't enforce varchar sizes anyway
return nil
}
return execMany(tx,
`ALTER TABLE mx_registrations ALTER COLUMN user_id TYPE TEXT`,
`ALTER TABLE mx_room_state ALTER COLUMN room_id TYPE TEXT`,
`ALTER TABLE mx_user_profile ALTER COLUMN room_id TYPE TEXT`,
`ALTER TABLE mx_user_profile ALTER COLUMN user_id TYPE TEXT`,
`ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT`,
`ALTER TABLE mx_user_profile ALTER COLUMN avatar_url TYPE TEXT`,
)
}}
}

View file

@ -1,22 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[30] = upgrade{"Store last read message timestamp in database", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`CREATE TABLE user_portal (
user_mxid TEXT,
portal_jid TEXT,
portal_receiver TEXT,
last_read_ts BIGINT NOT NULL DEFAULT 0,
PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
)`)
return err
}}
}

View file

@ -1,13 +0,0 @@
package upgrades
import (
"database/sql"
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
)
func init() {
upgrades[31] = upgrade{"Split last_used into last_encrypted and last_decrypted in crypto store", func(tx *sql.Tx, c context) error {
return sql_store_upgrade.Upgrades[5](tx, c.dialect.String())
}}
}

View file

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[32] = upgrade{"Store source broadcast list in message table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE message ADD COLUMN broadcast_list_jid TEXT`)
return err
}}
}

View file

@ -1,16 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[33] = upgrade{"Add personal filtering space info to user tables", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN space_room TEXT NOT NULL DEFAULT ''`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE user_portal ADD COLUMN in_space BOOLEAN NOT NULL DEFAULT false`)
return err
}}
}

View file

@ -1,20 +0,0 @@
package upgrades
import "database/sql"
func init() {
upgrades[34] = upgrade{"Add support for disappearing messages", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN expiration_time BIGINT NOT NULL DEFAULT 0 CHECK (expiration_time >= 0 AND expiration_time < 4294967296)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE disappearing_message (
room_id TEXT,
event_id TEXT,
expire_in BIGINT NOT NULL,
expire_at BIGINT,
PRIMARY KEY (room_id, event_id)
)`)
return err
}}
}

View file

@ -1,10 +0,0 @@
package upgrades
import "database/sql"
func init() {
upgrades[35] = upgrade{"Store approximate last seen timestamp of the main device", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN phone_last_seen BIGINT`)
return err
}}
}

View file

@ -1,30 +0,0 @@
package upgrades
import "database/sql"
func init() {
upgrades[36] = upgrade{"Store message error type as string", func(tx *sql.Tx, ctx context) error {
if ctx.dialect == Postgres {
_, err := tx.Exec("CREATE TYPE error_type AS ENUM ('', 'decryption_failed', 'media_not_found')")
if err != nil {
return err
}
}
_, err := tx.Exec("ALTER TABLE message ADD COLUMN error error_type NOT NULL DEFAULT ''")
if err != nil {
return err
}
_, err = tx.Exec("UPDATE message SET error='decryption_failed' WHERE decryption_error=true")
if err != nil {
return err
}
if ctx.dialect == Postgres {
// TODO do this on sqlite at some point
_, err = tx.Exec("ALTER TABLE message DROP COLUMN decryption_error")
if err != nil {
return err
}
}
return nil
}}
}

View file

@ -1,10 +0,0 @@
package upgrades
import "database/sql"
func init() {
upgrades[37] = upgrade{"Store timestamp for previous phone ping", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN phone_last_pinged BIGINT`)
return err
}}
}

View file

@ -1,39 +0,0 @@
package upgrades
import "database/sql"
func init() {
upgrades[38] = upgrade{"Add support for reactions", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE message ADD COLUMN type TEXT NOT NULL DEFAULT 'message'`)
if err != nil {
return err
}
if ctx.dialect == Postgres {
_, err = tx.Exec("ALTER TABLE message ALTER COLUMN type DROP DEFAULT")
if err != nil {
return err
}
}
_, err = tx.Exec("UPDATE message SET type='' WHERE error='decryption_failed'")
if err != nil {
return err
}
_, err = tx.Exec("UPDATE message SET type='fake' WHERE jid LIKE 'FAKE::%' OR mxid LIKE 'net.maunium.whatsapp.fake::%' OR jid=mxid")
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE reaction (
chat_jid TEXT,
chat_receiver TEXT,
target_jid TEXT,
sender TEXT,
mxid TEXT NOT NULL,
jid TEXT NOT NULL,
PRIMARY KEY (chat_jid, chat_receiver, target_jid, sender),
CONSTRAINT target_message_fkey FOREIGN KEY (chat_jid, chat_receiver, target_jid)
REFERENCES message(chat_jid, chat_receiver, jid)
ON DELETE CASCADE ON UPDATE CASCADE
)`)
return err
}}
}

View file

@ -1,45 +0,0 @@
package upgrades
import (
"database/sql"
"fmt"
)
func init() {
upgrades[39] = upgrade{"Add backfill queue", func(tx *sql.Tx, ctx context) error {
// The queue_id needs to auto-increment every insertion. For SQLite,
// INTEGER PRIMARY KEY is an alias for the ROWID, so it will
// auto-increment. See https://sqlite.org/lang_createtable.html#rowid
// For Postgres, we need to add GENERATED ALWAYS AS IDENTITY for the
// same functionality.
queueIDColumnTypeModifier := ""
if ctx.dialect == Postgres {
queueIDColumnTypeModifier = "GENERATED ALWAYS AS IDENTITY"
}
_, err := tx.Exec(fmt.Sprintf(`
CREATE TABLE backfill_queue (
queue_id INTEGER PRIMARY KEY %s,
user_mxid TEXT,
type INTEGER NOT NULL,
priority INTEGER NOT NULL,
portal_jid TEXT,
portal_receiver TEXT,
time_start TIMESTAMP,
time_end TIMESTAMP,
max_batch_events INTEGER NOT NULL,
max_total_events INTEGER,
batch_delay INTEGER,
completed_at TIMESTAMP,
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
)
`, queueIDColumnTypeModifier))
if err != nil {
return err
}
return err
}}
}

View file

@ -1,52 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[40] = upgrade{"Store history syncs for later backfills", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`
CREATE TABLE history_sync_conversation (
user_mxid TEXT,
conversation_id TEXT,
portal_jid TEXT,
portal_receiver TEXT,
last_message_timestamp TIMESTAMP,
archived BOOLEAN,
pinned INTEGER,
mute_end_time TIMESTAMP,
disappearing_mode INTEGER,
end_of_history_transfer_type INTEGER,
ephemeral_expiration INTEGER,
marked_as_unread BOOLEAN,
unread_count INTEGER,
PRIMARY KEY (user_mxid, conversation_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
)
`)
if err != nil {
return err
}
_, err = tx.Exec(`
CREATE TABLE history_sync_message (
user_mxid TEXT,
conversation_id TEXT,
message_id TEXT,
timestamp TIMESTAMP,
data BYTEA,
PRIMARY KEY (user_mxid, conversation_id, message_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE
)
`)
if err != nil {
return err
}
return nil
}}
}

View file

@ -1,20 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[41] = upgrade{"Update backfill queue tables to be sortable by priority", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`
UPDATE backfill_queue
SET type=CASE
WHEN type=1 THEN 200
WHEN type=2 THEN 300
ELSE type
END
WHERE type=1 OR type=2
`)
return err
}}
}

View file

@ -1,26 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[42] = upgrade{"Add table of media to request from the user's phone", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`
CREATE TABLE media_backfill_requests (
user_mxid TEXT,
portal_jid TEXT,
portal_receiver TEXT,
event_id TEXT,
media_key BYTEA,
status INTEGER,
error TEXT,
PRIMARY KEY (user_mxid, portal_jid, portal_receiver, event_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
)
`)
return err
}}
}

View file

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[43] = upgrade{"Add timezone column to user table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN timezone TEXT`)
return err
}}
}

View file

@ -1,34 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[44] = upgrade{"Add dispatch time to backfill queue", func(tx *sql.Tx, ctx context) error {
// First, add dispatch_time TIMESTAMP column
_, err := tx.Exec(`
ALTER TABLE backfill_queue
ADD COLUMN dispatch_time TIMESTAMP
`)
if err != nil {
return err
}
// For all previous jobs, set dispatch time to the completed time.
_, err = tx.Exec(`
UPDATE backfill_queue
SET dispatch_time=completed_at
`)
if err != nil {
return err
}
// Remove time_end from the backfill queue
_, err = tx.Exec(`
ALTER TABLE backfill_queue
DROP COLUMN time_end
`)
return err
}}
}

View file

@ -1,16 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[45] = upgrade{"Add inserted time to history sync message", func(tx *sql.Tx, ctx context) error {
// Add the inserted time TIMESTAMP column to history_sync_message
_, err := tx.Exec(`
ALTER TABLE history_sync_message
ADD COLUMN inserted_time TIMESTAMP
`)
return err
}}
}

View file

@ -1,25 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[46] = upgrade{"Create the backfill state table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`
CREATE TABLE backfill_state (
user_mxid TEXT,
portal_jid TEXT,
portal_receiver TEXT,
processing_batch BOOLEAN,
backfill_complete BOOLEAN,
first_expected_ts INTEGER,
PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
)
`)
return err
}}
}

View file

@ -0,0 +1,5 @@
-- v45: Add dispatch time to backfill queue
ALTER TABLE backfill_queue ADD COLUMN dispatch_time TIMESTAMP;
UPDATE backfill_queue SET dispatch_time=completed_at;
ALTER TABLE backfill_queue DROP COLUMN time_end;

View file

@ -0,0 +1,3 @@
-- v46: Add inserted time to history sync message
ALTER TABLE history_sync_message ADD COLUMN inserted_time TIMESTAMP;

View file

@ -0,0 +1,13 @@
-- v47: Add table for keeping track of backfill state
CREATE TABLE backfill_state (
user_mxid TEXT,
portal_jid TEXT,
portal_receiver TEXT,
processing_batch BOOLEAN,
backfill_complete BOOLEAN,
first_expected_ts TIMESTAMP,
PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal (jid, receiver) ON DELETE CASCADE
);

View file

@ -0,0 +1,7 @@
-- v48: Move crypto/state/whatsmeow store upgrade handling to separate systems
CREATE TABLE crypto_version (version INTEGER PRIMARY KEY);
INSERT INTO crypto_version VALUES (6);
CREATE TABLE whatsmeow_version (version INTEGER PRIMARY KEY);
INSERT INTO whatsmeow_version VALUES (1);
CREATE TABLE mx_version (version INTEGER PRIMARY KEY);
INSERT INTO mx_version VALUES (1);

View file

@ -1,180 +1,27 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package upgrades package upgrades
import ( import (
"database/sql" "database/sql"
"embed"
"errors" "errors"
"fmt"
"strings"
log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/util/dbutil"
) )
type Dialect int var Table dbutil.UpgradeTable
const ( //go:embed *.sql
Postgres Dialect = iota var rawUpgrades embed.FS
SQLite
)
func (dialect Dialect) String() string { func init() {
switch dialect { Table.Register(-1, 43, "Unsupported version", func(tx *sql.Tx, database *dbutil.Database) error {
case Postgres: return errors.New("please upgrade to mautrix-whatsapp v0.4.0 before upgrading to a newer version")
return "postgres" })
case SQLite: Table.RegisterFS(rawUpgrades)
return "sqlite3"
default:
return ""
}
}
type upgradeFunc func(*sql.Tx, context) error
type context struct {
dialect Dialect
db *sql.DB
log log.Logger
}
type upgrade struct {
message string
fn upgradeFunc
}
const NumberOfUpgrades = 47
var upgrades [NumberOfUpgrades]upgrade
var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
var ErrForeignTables = fmt.Errorf("the database contains foreign tables")
var ErrNotOwned = fmt.Errorf("the database is owned by")
var IgnoreForeignTables = false
const databaseOwner = "mautrix-whatsapp"
func GetVersion(db *sql.DB) (int, error) {
_, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
if err != nil {
return -1, err
}
version := 0
err = db.QueryRow("SELECT version FROM version LIMIT 1").Scan(&version)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return -1, err
}
return version, nil
}
const tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
const tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND table_name=$1)"
func tableExists(dialect Dialect, db *sql.DB, table string) (exists bool) {
if dialect == SQLite {
_ = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
} else if dialect == Postgres {
_ = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
}
return
}
const createOwnerTable = `
CREATE TABLE IF NOT EXISTS database_owner (
key INTEGER PRIMARY KEY DEFAULT 0,
owner TEXT NOT NULL
)
`
func CheckDatabaseOwner(dialect Dialect, db *sql.DB) error {
var owner string
if !IgnoreForeignTables {
if tableExists(dialect, db, "state_groups_state") {
return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
} else if tableExists(dialect, db, "goose_db_version") {
return fmt.Errorf("%w (found goose_db_version, possibly belonging to Dendrite)", ErrForeignTables)
}
}
if _, err := db.Exec(createOwnerTable); err != nil {
return fmt.Errorf("failed to ensure database owner table exists: %w", err)
} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
_, err = db.Exec("INSERT INTO database_owner (owner) VALUES ($1)", databaseOwner)
if err != nil {
return fmt.Errorf("failed to insert database owner: %w", err)
}
} else if err != nil {
return fmt.Errorf("failed to check database owner: %w", err)
} else if owner != databaseOwner {
return fmt.Errorf("%w %s", ErrNotOwned, owner)
}
return nil
}
func SetVersion(tx *sql.Tx, version int) error {
_, err := tx.Exec("DELETE FROM version")
if err != nil {
return err
}
_, err = tx.Exec("INSERT INTO version (version) VALUES ($1)", version)
return err
}
func execMany(tx *sql.Tx, queries ...string) error {
for _, query := range queries {
_, err := tx.Exec(query)
if err != nil {
return err
}
}
return nil
}
func Run(log log.Logger, dialectName string, db *sql.DB) error {
var dialect Dialect
switch strings.ToLower(dialectName) {
case "postgres":
dialect = Postgres
case "sqlite3":
dialect = SQLite
default:
return fmt.Errorf("unknown dialect %s", dialectName)
}
err := CheckDatabaseOwner(dialect, db)
if err != nil {
return err
}
version, err := GetVersion(db)
if err != nil {
return err
}
if version > NumberOfUpgrades {
return fmt.Errorf("%w: currently on v%d, latest known: v%d", ErrUnsupportedDatabaseVersion, version, NumberOfUpgrades)
}
log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)
for i, upgradeItem := range upgrades[version:] {
if upgradeItem.fn == nil {
continue
}
log.Infofln("Upgrading database to v%d: %s", version+i+1, upgradeItem.message)
var tx *sql.Tx
tx, err = db.Begin()
if err != nil {
return err
}
err = upgradeItem.fn(tx, context{dialect, db, log})
if err != nil {
return err
}
err = SetVersion(tx, version+i+1)
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
}
}
return nil
} }

View file

@ -24,6 +24,7 @@ import (
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
) )
@ -89,7 +90,7 @@ type User struct {
inSpaceCacheLock sync.Mutex inSpaceCacheLock sync.Mutex
} }
func (user *User) Scan(row Scannable) *User { func (user *User) Scan(row dbutil.Scannable) *User {
var username, timezone sql.NullString var username, timezone sql.NullString
var device, agent sql.NullByte var device, agent sql.NullByte
var phoneLastSeen, phoneLastPinged sql.NullInt64 var phoneLastSeen, phoneLastPinged sql.NullInt64

View file

@ -50,9 +50,9 @@ func (portal *Portal) ScheduleDisappearing() {
} }
} }
func (bridge *Bridge) SleepAndDeleteUpcoming() { func (br *WABridge) SleepAndDeleteUpcoming() {
for _, msg := range bridge.DB.DisappearingMessage.GetUpcomingScheduled(1 * time.Hour) { for _, msg := range br.DB.DisappearingMessage.GetUpcomingScheduled(1 * time.Hour) {
portal := bridge.GetPortalByMXID(msg.RoomID) portal := br.GetPortalByMXID(msg.RoomID)
if portal == nil { if portal == nil {
msg.Delete() msg.Delete()
} else { } else {

View file

@ -43,14 +43,6 @@ appservice:
max_conn_idle_time: null max_conn_idle_time: null
max_conn_lifetime: null max_conn_lifetime: null
# Settings for provisioning API
provisioning:
# Prefix for the provisioning API paths.
prefix: /_matrix/provision
# Shared secret for authentication. If set to "generate", a random secret will be generated,
# or if set to "disable", the provisioning API will be disabled.
shared_secret: generate
# The unique ID of this appservice. # The unique ID of this appservice.
id: whatsapp id: whatsapp
# Appservice bot details. # Appservice bot details.
@ -317,6 +309,14 @@ bridge:
# Verification by the bridge is not yet implemented. # Verification by the bridge is not yet implemented.
require_verification: true require_verification: true
# Settings for provisioning API
provisioning:
# Prefix for the provisioning API paths.
prefix: /_matrix/provision
# Shared secret for authentication. If set to "generate", a random secret will be generated,
# or if set to "disable", the provisioning API will be disabled.
shared_secret: generate
# Permissions for using the bridge. # Permissions for using the bridge.
# Permitted values: # Permitted values:
# relay - Talk through the relaybot (if enabled), no access otherwise # relay - Talk through the relaybot (if enabled), no access otherwise

View file

@ -37,7 +37,7 @@ var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```")
const mentionedJIDsContextKey = "net.maunium.whatsapp.mentioned_jids" const mentionedJIDsContextKey = "net.maunium.whatsapp.mentioned_jids"
type Formatter struct { type Formatter struct {
bridge *Bridge bridge *WABridge
matrixHTMLParser *format.HTMLParser matrixHTMLParser *format.HTMLParser
@ -46,7 +46,7 @@ type Formatter struct {
waReplFuncText map[*regexp.Regexp]func(string) string waReplFuncText map[*regexp.Regexp]func(string) string
} }
func NewFormatter(bridge *Bridge) *Formatter { func NewFormatter(bridge *WABridge) *Formatter {
formatter := &Formatter{ formatter := &Formatter{
bridge: bridge, bridge: bridge,
matrixHTMLParser: &format.HTMLParser{ matrixHTMLParser: &format.HTMLParser{

7
go.mod
View file

@ -14,10 +14,8 @@ require (
golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9 golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9
golang.org/x/net v0.0.0-20220513224357-95641704303c golang.org/x/net v0.0.0-20220513224357-95641704303c
google.golang.org/protobuf v1.28.0 google.golang.org/protobuf v1.28.0
gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99
maunium.net/go/mauflag v1.0.0
maunium.net/go/maulogger/v2 v2.3.2 maunium.net/go/maulogger/v2 v2.3.2
maunium.net/go/mautrix v0.11.1-0.20220518174602-87d2cd49a4d1 maunium.net/go/mautrix v0.11.1-0.20220521215033-d578d1a610d5
) )
require ( require (
@ -37,7 +35,8 @@ require (
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 // indirect golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 // indirect
golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886 // indirect golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886 // indirect
golang.org/x/text v0.3.7 // indirect golang.org/x/text v0.3.7 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.0 // indirect
maunium.net/go/mauflag v1.0.0 // indirect
) )
// Exclude some things that cause go.sum to explode // Exclude some things that cause go.sum to explode

9
go.sum
View file

@ -99,14 +99,13 @@ google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99 h1:dbuHpmKjkDzSOMKAWl10QNlgaZUd3V1q99xc81tt2Kc= gopkg.in/yaml.v3 v3.0.0 h1:hjy8E9ON/egN1tAYqKb61G10WtihqetD4sz2H+8nIeA=
gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0= maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0=
maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A= maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A=
maunium.net/go/mautrix v0.11.1-0.20220518174602-87d2cd49a4d1 h1:+KEF+nSuBfHWsfQRz92YP/DdSLbComLoXCXgcrH6WRU= maunium.net/go/mautrix v0.11.1-0.20220521215033-d578d1a610d5 h1:7ZORg2h+lflc1HwjTKCXZnykauXD+wzbW+VDknbv6SU=
maunium.net/go/mautrix v0.11.1-0.20220518174602-87d2cd49a4d1/go.mod h1:K29EcHwsNg6r7fMfwvi0GHQ9o5wSjqB9+Q8RjCIQEjA= maunium.net/go/mautrix v0.11.1-0.20220521215033-d578d1a610d5/go.mod h1:oma8o6Y/5jcViBlDbX7tp1ajP2XP+b78h8twdI+zKI0=

515
main.go
View file

@ -18,43 +18,26 @@ package main
import ( import (
_ "embed" _ "embed"
"errors"
"fmt"
"net/http" "net/http"
"os" "os"
"os/signal"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"syscall"
"time" "time"
"google.golang.org/protobuf/proto"
"go.mau.fi/whatsmeow" "go.mau.fi/whatsmeow"
waProto "go.mau.fi/whatsmeow/binary/proto" waProto "go.mau.fi/whatsmeow/binary/proto"
"go.mau.fi/whatsmeow/store" "go.mau.fi/whatsmeow/store"
"go.mau.fi/whatsmeow/store/sqlstore" "go.mau.fi/whatsmeow/store/sqlstore"
"go.mau.fi/whatsmeow/types" "go.mau.fi/whatsmeow/types"
"google.golang.org/protobuf/proto"
flag "maunium.net/go/mauflag" "maunium.net/go/mautrix/bridge"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/configupgrade" "maunium.net/go/mautrix/util/configupgrade"
"maunium.net/go/mautrix-whatsapp/config" "maunium.net/go/mautrix-whatsapp/config"
"maunium.net/go/mautrix-whatsapp/database" "maunium.net/go/mautrix-whatsapp/database"
"maunium.net/go/mautrix-whatsapp/database/upgrades"
)
// The name and repo URL of the bridge.
var (
Name = "mautrix-whatsapp"
URL = "https://github.com/mautrix/whatsapp"
) )
// Information to find out exactly which commit the bridge was built from. // Information to find out exactly which commit the bridge was built from.
@ -65,120 +48,19 @@ var (
BuildTime = "unknown" BuildTime = "unknown"
) )
var (
// Version is the version number of the bridge. Changed manually when making a release.
Version = "0.4.0"
// WAVersion is the version number exposed to WhatsApp. Filled in init()
WAVersion = ""
// VersionString is the bridge version, plus commit information. Filled in init() using the build-time values.
VersionString = ""
)
//go:embed example-config.yaml //go:embed example-config.yaml
var ExampleConfig string var ExampleConfig string
func init() { type WABridge struct {
if len(Tag) > 0 && Tag[0] == 'v' { bridge.Bridge
Tag = Tag[1:] MatrixHandler *MatrixHandler
} Config *config.Config
if Tag != Version { DB *database.Database
suffix := "" Provisioning *ProvisioningAPI
if !strings.HasSuffix(Version, "+dev") { Formatter *Formatter
suffix = "+dev" Metrics *MetricsHandler
} WAContainer *sqlstore.Container
if len(Commit) > 8 { WAVersion string
Version = fmt.Sprintf("%s%s.%s", Version, suffix, Commit[:8])
} else {
Version = fmt.Sprintf("%s%s.unknown", Version, suffix)
}
}
mautrix.DefaultUserAgent = fmt.Sprintf("mautrix-whatsapp/%s %s", Version, mautrix.DefaultUserAgent)
WAVersion = strings.FieldsFunc(Version, func(r rune) bool { return r == '-' || r == '+' })[0]
VersionString = fmt.Sprintf("%s %s (%s)", Name, Version, BuildTime)
config.ExampleConfig = ExampleConfig
}
var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String()
var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config to disk.", "false").Bool()
var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String()
var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool()
var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool()
var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool()
var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool()
var migrateFrom = flag.Make().LongKey("migrate-db").Usage("Source database type and URI to migrate from.").Bool()
var wantHelp, _ = flag.MakeHelpFlag()
func (bridge *Bridge) GenerateRegistration() {
if *dontSaveConfig {
// We need to save the generated as_token and hs_token in the config
_, _ = fmt.Fprintln(os.Stderr, "--no-update is not compatible with --generate-registration")
os.Exit(5)
}
reg, err := bridge.Config.NewRegistration()
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to generate registration:", err)
os.Exit(20)
}
err = reg.Save(*registrationPath)
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to save registration:", err)
os.Exit(21)
}
err = config.Mutate(*configPath, func(helper *configupgrade.Helper) {
helper.Set(configupgrade.Str, bridge.Config.AppService.ASToken, "appservice", "as_token")
helper.Set(configupgrade.Str, bridge.Config.AppService.HSToken, "appservice", "hs_token")
})
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to save config:", err)
os.Exit(22)
}
fmt.Println("Registration generated. Add the path to the registration to your Synapse config, restart it, then start the bridge.")
os.Exit(0)
}
func (bridge *Bridge) MigrateDatabase() {
oldDB, err := database.New(config.DatabaseConfig{Type: flag.Arg(0), URI: flag.Arg(1)}, log.DefaultLogger)
if err != nil {
fmt.Println("Failed to open old database:", err)
os.Exit(30)
}
err = oldDB.Init()
if err != nil {
fmt.Println("Failed to upgrade old database:", err)
os.Exit(31)
}
newDB, err := database.New(bridge.Config.AppService.Database, log.DefaultLogger)
if err != nil {
fmt.Println("Failed to open new database:", err)
os.Exit(32)
}
err = newDB.Init()
if err != nil {
fmt.Println("Failed to upgrade new database:", err)
os.Exit(33)
}
database.Migrate(oldDB, newDB)
}
type Bridge struct {
AS *appservice.AppService
EventProcessor *appservice.EventProcessor
MatrixHandler *MatrixHandler
Config *config.Config
DB *database.Database
Log log.Logger
StateStore *database.SQLStateStore
Provisioning *ProvisioningAPI
Bot *appservice.IntentAPI
Formatter *Formatter
Crypto Crypto
Metrics *MetricsHandler
WAContainer *sqlstore.Container
usersByMXID map[id.UserID]*User usersByMXID map[id.UserID]*User
usersByUsername map[string]*User usersByUsername map[string]*User
@ -195,111 +77,32 @@ type Bridge struct {
puppetsLock sync.Mutex puppetsLock sync.Mutex
} }
type Crypto interface { func (br *WABridge) Init() {
HandleMemberEvent(*event.Event) Segment.log = br.Log.Sub("Segment")
Decrypt(*event.Event) (*event.Event, error) Segment.key = br.Config.SegmentKey
Encrypt(id.RoomID, event.Type, event.Content) (*event.EncryptedEventContent, error)
WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
ResetSession(id.RoomID)
Init() error
Start()
Stop()
}
func (bridge *Bridge) ensureConnection() {
for {
versions, err := bridge.Bot.Versions()
if err != nil {
bridge.Log.Errorfln("Failed to connect to homeserver: %v. Retrying in 10 seconds...", err)
time.Sleep(10 * time.Second)
continue
}
if !versions.ContainsGreaterOrEqual(mautrix.SpecV11) {
bridge.Log.Warnfln("Server isn't advertising modern spec versions")
}
resp, err := bridge.Bot.Whoami()
if err != nil {
if errors.Is(err, mautrix.MUnknownToken) {
bridge.Log.Fatalln("The as_token was not accepted. Is the registration file installed in your homeserver correctly?")
os.Exit(16)
} else if errors.Is(err, mautrix.MExclusive) {
bridge.Log.Fatalln("The as_token was accepted, but the /register request was not. Are the homeserver domain and username template in the config correct, and do they match the values in the registration?")
os.Exit(16)
}
bridge.Log.Errorfln("Failed to connect to homeserver: %v. Retrying in 10 seconds...", err)
time.Sleep(10 * time.Second)
} else if resp.UserID != bridge.Bot.UserID {
bridge.Log.Fatalln("Unexpected user ID in whoami call: got %s, expected %s", resp.UserID, bridge.Bot.UserID)
os.Exit(17)
} else {
break
}
}
}
func (bridge *Bridge) Init() {
var err error
bridge.AS, err = bridge.Config.MakeAppService()
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to initialize AppService:", err)
os.Exit(11)
}
_, _ = bridge.AS.Init()
bridge.Log = log.Create()
bridge.Config.Logging.Configure(bridge.Log)
log.DefaultLogger = bridge.Log.(*log.BasicLogger)
if len(bridge.Config.Logging.FileNameFormat) > 0 {
err = log.OpenFile()
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to open log file:", err)
os.Exit(12)
}
}
bridge.AS.Log = log.Sub("Matrix")
bridge.Bot = bridge.AS.BotIntent()
bridge.Log.Infoln("Initializing", VersionString)
bridge.Log.Debugln("Initializing database connection")
bridge.DB, err = database.New(bridge.Config.AppService.Database, bridge.Log)
if err != nil {
bridge.Log.Fatalln("Failed to initialize database connection:", err)
os.Exit(14)
}
bridge.Log.Debugln("Initializing state store")
bridge.StateStore = database.NewSQLStateStore(bridge.DB)
bridge.AS.StateStore = bridge.StateStore
Segment.log = bridge.Log.Sub("Segment")
Segment.key = bridge.Config.SegmentKey
if Segment.IsEnabled() { if Segment.IsEnabled() {
Segment.log.Infoln("Segment metrics are enabled") Segment.log.Infoln("Segment metrics are enabled")
} }
bridge.WAContainer = sqlstore.NewWithDB(bridge.DB.DB, bridge.Config.AppService.Database.Type, nil) br.DB = database.New(br.Bridge.DB)
bridge.WAContainer.DatabaseErrorHandler = bridge.DB.HandleSignalStoreError br.WAContainer = sqlstore.NewWithDB(br.DB.DB, br.DB.Dialect.String(), nil)
br.WAContainer.DatabaseErrorHandler = br.DB.HandleSignalStoreError
ss := bridge.Config.AppService.Provisioning.SharedSecret ss := br.Config.Bridge.Provisioning.SharedSecret
if len(ss) > 0 && ss != "disable" { if len(ss) > 0 && ss != "disable" {
bridge.Provisioning = &ProvisioningAPI{bridge: bridge} br.Provisioning = &ProvisioningAPI{bridge: br}
} }
bridge.Log.Debugln("Initializing Matrix event processor") br.Log.Debugln("Initializing Matrix event handler")
bridge.EventProcessor = appservice.NewEventProcessor(bridge.AS) br.MatrixHandler = NewMatrixHandler(br)
bridge.Log.Debugln("Initializing Matrix event handler") br.Formatter = NewFormatter(br)
bridge.MatrixHandler = NewMatrixHandler(bridge) br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.Log.Sub("Metrics"), br.DB)
bridge.Formatter = NewFormatter(bridge)
bridge.Crypto = NewCryptoHelper(bridge)
bridge.Metrics = NewMetricsHandler(bridge.Config.Metrics.Listen, bridge.Log.Sub("Metrics"), bridge.DB)
store.BaseClientPayload.UserAgent.OsVersion = proto.String(WAVersion) store.BaseClientPayload.UserAgent.OsVersion = proto.String(br.WAVersion)
store.BaseClientPayload.UserAgent.OsBuildNumber = proto.String(WAVersion) store.BaseClientPayload.UserAgent.OsBuildNumber = proto.String(br.WAVersion)
store.CompanionProps.Os = proto.String(bridge.Config.WhatsApp.OSName) store.CompanionProps.Os = proto.String(br.Config.WhatsApp.OSName)
store.CompanionProps.RequireFullSync = proto.Bool(bridge.Config.Bridge.HistorySync.RequestFullSync) store.CompanionProps.RequireFullSync = proto.Bool(br.Config.Bridge.HistorySync.RequestFullSync)
versionParts := strings.Split(WAVersion, ".") versionParts := strings.Split(br.WAVersion, ".")
if len(versionParts) > 2 { if len(versionParts) > 2 {
primary, _ := strconv.Atoi(versionParts[0]) primary, _ := strconv.Atoi(versionParts[0])
secondary, _ := strconv.Atoi(versionParts[1]) secondary, _ := strconv.Atoi(versionParts[1])
@ -308,161 +111,107 @@ func (bridge *Bridge) Init() {
store.CompanionProps.Version.Secondary = proto.Uint32(uint32(secondary)) store.CompanionProps.Version.Secondary = proto.Uint32(uint32(secondary))
store.CompanionProps.Version.Tertiary = proto.Uint32(uint32(tertiary)) store.CompanionProps.Version.Tertiary = proto.Uint32(uint32(tertiary))
} }
platformID, ok := waProto.CompanionProps_CompanionPropsPlatformType_value[strings.ToUpper(bridge.Config.WhatsApp.BrowserName)] platformID, ok := waProto.CompanionProps_CompanionPropsPlatformType_value[strings.ToUpper(br.Config.WhatsApp.BrowserName)]
if ok { if ok {
store.CompanionProps.PlatformType = waProto.CompanionProps_CompanionPropsPlatformType(platformID).Enum() store.CompanionProps.PlatformType = waProto.CompanionProps_CompanionPropsPlatformType(platformID).Enum()
} }
} }
func (bridge *Bridge) Start() { func (br *WABridge) Start() {
bridge.Log.Debugln("Running database upgrades") err := br.WAContainer.Upgrade()
err := bridge.DB.Init() if err != nil {
if err != nil && (!errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) || !*ignoreUnsupportedDatabase) { br.Log.Fatalln("Failed to upgrade whatsmeow database: %v", err)
bridge.Log.Fatalln("Failed to initialize database:", err)
if errors.Is(err, upgrades.ErrForeignTables) {
bridge.Log.Infoln("You can use --ignore-foreign-tables to ignore this error")
} else if errors.Is(err, upgrades.ErrNotOwned) {
bridge.Log.Infoln("Sharing the same database with different programs is not supported")
} else if errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) {
bridge.Log.Infoln("Downgrading the bridge is not supported")
}
os.Exit(15) os.Exit(15)
} }
bridge.Log.Debugln("Checking connection to homeserver") if br.Provisioning != nil {
bridge.ensureConnection() br.Log.Debugln("Initializing provisioning API")
if bridge.Crypto != nil { br.Provisioning.Init()
err = bridge.Crypto.Init()
if err != nil {
bridge.Log.Fatalln("Error initializing end-to-bridge encryption:", err)
os.Exit(19)
}
} }
if bridge.Provisioning != nil { go br.CheckWhatsAppUpdate()
bridge.Log.Debugln("Initializing provisioning API") go br.StartUsers()
bridge.Provisioning.Init() if br.Config.Metrics.Enabled {
} go br.Metrics.Start()
bridge.Log.Debugln("Starting application service HTTP server")
go bridge.AS.Start()
bridge.Log.Debugln("Starting event processor")
go bridge.EventProcessor.Start()
go bridge.CheckWhatsAppUpdate()
go bridge.UpdateBotProfile()
if bridge.Crypto != nil {
go bridge.Crypto.Start()
}
go bridge.StartUsers()
if bridge.Config.Metrics.Enabled {
go bridge.Metrics.Start()
} }
if bridge.Config.Bridge.ResendBridgeInfo { if br.Config.Bridge.ResendBridgeInfo {
go bridge.ResendBridgeInfo() go br.ResendBridgeInfo()
} }
go bridge.Loop() go br.Loop()
bridge.AS.Ready = true
} }
func (bridge *Bridge) CheckWhatsAppUpdate() { func (br *WABridge) CheckWhatsAppUpdate() {
bridge.Log.Debugfln("Checking for WhatsApp web update") br.Log.Debugfln("Checking for WhatsApp web update")
resp, err := whatsmeow.CheckUpdate(http.DefaultClient) resp, err := whatsmeow.CheckUpdate(http.DefaultClient)
if err != nil { if err != nil {
bridge.Log.Warnfln("Failed to check for WhatsApp web update: %v", err) br.Log.Warnfln("Failed to check for WhatsApp web update: %v", err)
return return
} }
if store.GetWAVersion() == resp.ParsedVersion { if store.GetWAVersion() == resp.ParsedVersion {
bridge.Log.Debugfln("Bridge is using latest WhatsApp web protocol") br.Log.Debugfln("Bridge is using latest WhatsApp web protocol")
} else if store.GetWAVersion().LessThan(resp.ParsedVersion) { } else if store.GetWAVersion().LessThan(resp.ParsedVersion) {
if resp.IsBelowHard || resp.IsBroken { if resp.IsBelowHard || resp.IsBroken {
bridge.Log.Warnfln("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion) br.Log.Warnfln("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
} else if resp.IsBelowSoft { } else if resp.IsBelowSoft {
bridge.Log.Infofln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion) br.Log.Infofln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
} else { } else {
bridge.Log.Debugfln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion) br.Log.Debugfln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
} }
} else { } else {
bridge.Log.Debugfln("Bridge is using newer than latest WhatsApp web protocol") br.Log.Debugfln("Bridge is using newer than latest WhatsApp web protocol")
} }
} }
func (bridge *Bridge) Loop() { func (br *WABridge) Loop() {
for { for {
bridge.SleepAndDeleteUpcoming() br.SleepAndDeleteUpcoming()
time.Sleep(1 * time.Hour) time.Sleep(1 * time.Hour)
bridge.WarnUsersAboutDisconnection() br.WarnUsersAboutDisconnection()
} }
} }
func (bridge *Bridge) WarnUsersAboutDisconnection() { func (br *WABridge) WarnUsersAboutDisconnection() {
bridge.usersLock.Lock() br.usersLock.Lock()
for _, user := range bridge.usersByUsername { for _, user := range br.usersByUsername {
if user.IsConnected() && !user.PhoneRecentlySeen(true) { if user.IsConnected() && !user.PhoneRecentlySeen(true) {
go user.sendPhoneOfflineWarning() go user.sendPhoneOfflineWarning()
} }
} }
bridge.usersLock.Unlock() br.usersLock.Unlock()
} }
func (bridge *Bridge) ResendBridgeInfo() { func (br *WABridge) ResendBridgeInfo() {
if *dontSaveConfig { // FIXME
bridge.Log.Warnln("Not setting resend_bridge_info to false in config due to --no-update flag") //if *dontSaveConfig {
} else { // br.Log.Warnln("Not setting resend_bridge_info to false in config due to --no-update flag")
err := config.Mutate(*configPath, func(helper *configupgrade.Helper) { //} else {
helper.Set(configupgrade.Bool, "false", "bridge", "resend_bridge_info") // err := config.Mutate(*configPath, func(helper *configupgrade.Helper) {
}) // helper.Set(configupgrade.Bool, "false", "bridge", "resend_bridge_info")
if err != nil { // })
bridge.Log.Errorln("Failed to save config after setting resend_bridge_info to false:", err) // if err != nil {
} // br.Log.Errorln("Failed to save config after setting resend_bridge_info to false:", err)
} // }
bridge.Log.Infoln("Re-sending bridge info state event to all portals") //}
for _, portal := range bridge.GetAllPortals() { //br.Log.Infoln("Re-sending bridge info state event to all portals")
portal.UpdateBridgeInfo() //for _, portal := range br.GetAllPortals() {
} // portal.UpdateBridgeInfo()
bridge.Log.Infoln("Finished re-sending bridge info state events") //}
//br.Log.Infoln("Finished re-sending bridge info state events")
} }
func (bridge *Bridge) UpdateBotProfile() { func (br *WABridge) StartUsers() {
bridge.Log.Debugln("Updating bot profile") br.Log.Debugln("Starting users")
botConfig := &bridge.Config.AppService.Bot
var err error
var mxc id.ContentURI
if botConfig.Avatar == "remove" {
err = bridge.Bot.SetAvatarURL(mxc)
} else if len(botConfig.Avatar) > 0 {
mxc, err = id.ParseContentURI(botConfig.Avatar)
if err == nil {
err = bridge.Bot.SetAvatarURL(mxc)
}
botConfig.ParsedAvatar = mxc
}
if err != nil {
bridge.Log.Warnln("Failed to update bot avatar:", err)
}
if botConfig.Displayname == "remove" {
err = bridge.Bot.SetDisplayName("")
} else if len(botConfig.Displayname) > 0 {
err = bridge.Bot.SetDisplayName(botConfig.Displayname)
}
if err != nil {
bridge.Log.Warnln("Failed to update bot displayname:", err)
}
}
func (bridge *Bridge) StartUsers() {
bridge.Log.Debugln("Starting users")
foundAnySessions := false foundAnySessions := false
for _, user := range bridge.GetAllUsers() { for _, user := range br.GetAllUsers() {
if !user.JID.IsEmpty() { if !user.JID.IsEmpty() {
foundAnySessions = true foundAnySessions = true
} }
go user.Connect() go user.Connect()
} }
if !foundAnySessions { if !foundAnySessions {
bridge.sendGlobalBridgeState(BridgeState{StateEvent: StateUnconfigured}.fill(nil)) br.sendGlobalBridgeState(BridgeState{StateEvent: StateUnconfigured}.fill(nil))
} }
bridge.Log.Debugln("Starting custom puppets") br.Log.Debugln("Starting custom puppets")
for _, loopuppet := range bridge.GetAllPuppetsWithCustomMXID() { for _, loopuppet := range br.GetAllPuppetsWithCustomMXID() {
go func(puppet *Puppet) { go func(puppet *Puppet) {
puppet.log.Debugln("Starting custom puppet", puppet.CustomMXID) puppet.log.Debugln("Starting custom puppet", puppet.CustomMXID)
err := puppet.StartCustomMXID(true) err := puppet.StartCustomMXID(true)
@ -473,80 +222,37 @@ func (bridge *Bridge) StartUsers() {
} }
} }
func (bridge *Bridge) Stop() { func (br *WABridge) Stop() {
if bridge.Crypto != nil { if br.Crypto != nil {
bridge.Crypto.Stop() br.Crypto.Stop()
} }
bridge.AS.Stop() br.AS.Stop()
bridge.Metrics.Stop() br.Metrics.Stop()
bridge.EventProcessor.Stop() br.EventProcessor.Stop()
for _, user := range bridge.usersByUsername { for _, user := range br.usersByUsername {
if user.Client == nil { if user.Client == nil {
continue continue
} }
bridge.Log.Debugln("Disconnecting", user.MXID) br.Log.Debugln("Disconnecting", user.MXID)
user.Client.Disconnect() user.Client.Disconnect()
close(user.historySyncs) close(user.historySyncs)
} }
} }
func (bridge *Bridge) Main() { func (br *WABridge) GetExampleConfig() string {
configData, upgraded, err := config.Upgrade(*configPath, !*dontSaveConfig) return ExampleConfig
if err != nil { }
_, _ = fmt.Fprintln(os.Stderr, "Error updating config:", err)
if configData == nil { func (br *WABridge) GetConfigPtr() interface{} {
os.Exit(10) br.Config = &config.Config{
} BaseConfig: &br.Bridge.Config,
} }
br.Config.BaseConfig.Bridge = &br.Config.Bridge
bridge.Config, err = config.Load(configData, upgraded) return br.Config
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to parse config:", err)
os.Exit(10)
}
if *generateRegistration {
bridge.GenerateRegistration()
return
} else if *migrateFrom {
bridge.MigrateDatabase()
return
}
bridge.Init()
bridge.Log.Infoln("Bridge initialization complete, starting...")
bridge.Start()
bridge.Log.Infoln("Bridge started!")
c := make(chan os.Signal)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
<-c
bridge.Log.Infoln("Interrupt received, stopping...")
bridge.Stop()
bridge.Log.Infoln("Bridge stopped.")
os.Exit(0)
} }
func main() { func main() {
flag.SetHelpTitles( br := &WABridge{
"mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.",
"mautrix-whatsapp [-h] [-c <path>] [-r <path>] [-g] [--migrate-db <source type> <source uri>]")
err := flag.Parse()
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, err)
flag.PrintHelp()
os.Exit(1)
} else if *wantHelp {
flag.PrintHelp()
os.Exit(0)
} else if *version {
fmt.Println(VersionString)
return
}
upgrades.IgnoreForeignTables = *ignoreForeignTables
(&Bridge{
usersByMXID: make(map[id.UserID]*User), usersByMXID: make(map[id.UserID]*User),
usersByUsername: make(map[string]*User), usersByUsername: make(map[string]*User),
spaceRooms: make(map[id.RoomID]*User), spaceRooms: make(map[id.RoomID]*User),
@ -555,5 +261,24 @@ func main() {
portalsByJID: make(map[database.PortalKey]*Portal), portalsByJID: make(map[database.PortalKey]*Portal),
puppets: make(map[types.JID]*Puppet), puppets: make(map[types.JID]*Puppet),
puppetsByCustomMXID: make(map[id.UserID]*Puppet), puppetsByCustomMXID: make(map[id.UserID]*Puppet),
}).Main() }
br.Bridge = bridge.Bridge{
Name: "mautrix-whatsapp",
URL: "https://github.com/mautrix/whatsapp",
Description: "A Matrix-WhatsApp puppeting bridge.",
Version: "0.4.0",
ProtocolName: "WhatsApp",
ConfigUpgrader: &configupgrade.StructUpgrader{
SimpleUpgrader: configupgrade.SimpleUpgrader(config.DoUpgrade),
Blocks: config.SpacedBlocks,
Base: ExampleConfig,
},
Child: br,
}
br.InitVersion(Tag, Commit, BuildTime)
br.WAVersion = strings.FieldsFunc(br.Version, func(r rune) bool { return r == '-' || r == '+' })[0]
br.Main()
} }

View file

@ -28,6 +28,7 @@ import (
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format" "maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
@ -36,13 +37,13 @@ import (
) )
type MatrixHandler struct { type MatrixHandler struct {
bridge *Bridge bridge *WABridge
as *appservice.AppService as *appservice.AppService
log maulogger.Logger log maulogger.Logger
cmd *CommandHandler cmd *CommandHandler
} }
func NewMatrixHandler(bridge *Bridge) *MatrixHandler { func NewMatrixHandler(bridge *WABridge) *MatrixHandler {
handler := &MatrixHandler{ handler := &MatrixHandler{
bridge: bridge, bridge: bridge,
as: bridge.AS, as: bridge.AS,
@ -362,7 +363,7 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) {
decrypted, err := mx.bridge.Crypto.Decrypt(evt) decrypted, err := mx.bridge.Crypto.Decrypt(evt)
decryptionRetryCount := 0 decryptionRetryCount := 0
if errors.Is(err, NoSessionFound) { if errors.Is(err, bridge.NoSessionFound) {
content := evt.Content.AsEncrypted() content := evt.Content.AsEncrypted()
mx.log.Debugfln("Couldn't find session %s trying to decrypt %s, waiting %d seconds...", content.SessionID, evt.ID, int(sessionWaitTimeout.Seconds())) mx.log.Debugfln("Couldn't find session %s trying to decrypt %s, waiting %d seconds...", content.SessionID, evt.ID, int(sessionWaitTimeout.Seconds()))
mx.as.SendErrorMessageSendCheckpoint(evt, appservice.StepDecrypted, err, false, decryptionRetryCount) mx.as.SendErrorMessageSendCheckpoint(evt, appservice.StepDecrypted, err, false, decryptionRetryCount)

View file

@ -1,17 +0,0 @@
//go:build !cgo || nocrypto
package main
import (
"errors"
)
func NewCryptoHelper(bridge *Bridge) Crypto {
if !bridge.Config.Bridge.Encryption.Allow {
bridge.Log.Warnln("Bridge built without end-to-bridge encryption, but encryption is enabled in config")
}
bridge.Log.Debugln("Bridge built without end-to-bridge encryption")
return nil
}
var NoSessionFound = errors.New("nil")

View file

@ -45,6 +45,7 @@ import (
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format" "maunium.net/go/mautrix/format"
@ -69,68 +70,72 @@ const PrivateChatTopic = "WhatsApp private chat"
var ErrStatusBroadcastDisabled = errors.New("status bridging is disabled") var ErrStatusBroadcastDisabled = errors.New("status bridging is disabled")
func (bridge *Bridge) GetPortalByMXID(mxid id.RoomID) *Portal { func (br *WABridge) GetPortalByMXID(mxid id.RoomID) *Portal {
bridge.portalsLock.Lock() br.portalsLock.Lock()
defer bridge.portalsLock.Unlock() defer br.portalsLock.Unlock()
portal, ok := bridge.portalsByMXID[mxid] portal, ok := br.portalsByMXID[mxid]
if !ok { if !ok {
return bridge.loadDBPortal(bridge.DB.Portal.GetByMXID(mxid), nil) return br.loadDBPortal(br.DB.Portal.GetByMXID(mxid), nil)
} }
return portal return portal
} }
func (bridge *Bridge) GetPortalByJID(key database.PortalKey) *Portal { func (br *WABridge) GetIPortalByMXID(mxid id.RoomID) bridge.Portal {
bridge.portalsLock.Lock() return br.GetPortalByMXID(mxid)
defer bridge.portalsLock.Unlock() }
portal, ok := bridge.portalsByJID[key]
func (br *WABridge) GetPortalByJID(key database.PortalKey) *Portal {
br.portalsLock.Lock()
defer br.portalsLock.Unlock()
portal, ok := br.portalsByJID[key]
if !ok { if !ok {
return bridge.loadDBPortal(bridge.DB.Portal.GetByJID(key), &key) return br.loadDBPortal(br.DB.Portal.GetByJID(key), &key)
} }
return portal return portal
} }
func (bridge *Bridge) GetAllPortals() []*Portal { func (br *WABridge) GetAllPortals() []*Portal {
return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAll()) return br.dbPortalsToPortals(br.DB.Portal.GetAll())
} }
func (bridge *Bridge) GetAllPortalsForUser(userID id.UserID) []*Portal { func (br *WABridge) GetAllPortalsForUser(userID id.UserID) []*Portal {
return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAllForUser(userID)) return br.dbPortalsToPortals(br.DB.Portal.GetAllForUser(userID))
} }
func (bridge *Bridge) GetAllPortalsByJID(jid types.JID) []*Portal { func (br *WABridge) GetAllPortalsByJID(jid types.JID) []*Portal {
return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAllByJID(jid)) return br.dbPortalsToPortals(br.DB.Portal.GetAllByJID(jid))
} }
func (bridge *Bridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal { func (br *WABridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal {
bridge.portalsLock.Lock() br.portalsLock.Lock()
defer bridge.portalsLock.Unlock() defer br.portalsLock.Unlock()
output := make([]*Portal, len(dbPortals)) output := make([]*Portal, len(dbPortals))
for index, dbPortal := range dbPortals { for index, dbPortal := range dbPortals {
if dbPortal == nil { if dbPortal == nil {
continue continue
} }
portal, ok := bridge.portalsByJID[dbPortal.Key] portal, ok := br.portalsByJID[dbPortal.Key]
if !ok { if !ok {
portal = bridge.loadDBPortal(dbPortal, nil) portal = br.loadDBPortal(dbPortal, nil)
} }
output[index] = portal output[index] = portal
} }
return output return output
} }
func (bridge *Bridge) loadDBPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal { func (br *WABridge) loadDBPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal {
if dbPortal == nil { if dbPortal == nil {
if key == nil { if key == nil {
return nil return nil
} }
dbPortal = bridge.DB.Portal.New() dbPortal = br.DB.Portal.New()
dbPortal.Key = *key dbPortal.Key = *key
dbPortal.Insert() dbPortal.Insert()
} }
portal := bridge.NewPortal(dbPortal) portal := br.NewPortal(dbPortal)
bridge.portalsByJID[portal.Key] = portal br.portalsByJID[portal.Key] = portal
if len(portal.MXID) > 0 { if len(portal.MXID) > 0 {
bridge.portalsByMXID[portal.MXID] = portal br.portalsByMXID[portal.MXID] = portal
} }
return portal return portal
} }
@ -139,14 +144,14 @@ func (portal *Portal) GetUsers() []*User {
return nil return nil
} }
func (bridge *Bridge) newBlankPortal(key database.PortalKey) *Portal { func (br *WABridge) newBlankPortal(key database.PortalKey) *Portal {
portal := &Portal{ portal := &Portal{
bridge: bridge, bridge: br,
log: bridge.Log.Sub(fmt.Sprintf("Portal/%s", key)), log: br.Log.Sub(fmt.Sprintf("Portal/%s", key)),
messages: make(chan PortalMessage, bridge.Config.Bridge.PortalMessageBuffer), messages: make(chan PortalMessage, br.Config.Bridge.PortalMessageBuffer),
matrixMessages: make(chan PortalMatrixMessage, bridge.Config.Bridge.PortalMessageBuffer), matrixMessages: make(chan PortalMatrixMessage, br.Config.Bridge.PortalMessageBuffer),
mediaRetries: make(chan PortalMediaRetry, bridge.Config.Bridge.PortalMessageBuffer), mediaRetries: make(chan PortalMediaRetry, br.Config.Bridge.PortalMessageBuffer),
mediaErrorCache: make(map[types.MessageID]*FailedMediaMeta), mediaErrorCache: make(map[types.MessageID]*FailedMediaMeta),
} }
@ -154,15 +159,15 @@ func (bridge *Bridge) newBlankPortal(key database.PortalKey) *Portal {
return portal return portal
} }
func (bridge *Bridge) NewManualPortal(key database.PortalKey) *Portal { func (br *WABridge) NewManualPortal(key database.PortalKey) *Portal {
portal := bridge.newBlankPortal(key) portal := br.newBlankPortal(key)
portal.Portal = bridge.DB.Portal.New() portal.Portal = br.DB.Portal.New()
portal.Key = key portal.Key = key
return portal return portal
} }
func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal { func (br *WABridge) NewPortal(dbPortal *database.Portal) *Portal {
portal := bridge.newBlankPortal(dbPortal.Key) portal := br.newBlankPortal(dbPortal.Key)
portal.Portal = dbPortal portal.Portal = dbPortal
return portal return portal
} }
@ -203,7 +208,7 @@ type recentlyHandledWrapper struct {
type Portal struct { type Portal struct {
*database.Portal *database.Portal
bridge *Bridge bridge *WABridge
log log.Logger log log.Logger
roomCreateLock sync.Mutex roomCreateLock sync.Mutex
@ -229,6 +234,10 @@ type Portal struct {
relayUser *User relayUser *User
} }
func (portal *Portal) IsEncrypted() bool {
return portal.Encrypted
}
func (portal *Portal) handleMessageLoopItem(msg PortalMessage) { func (portal *Portal) handleMessageLoopItem(msg PortalMessage) {
if len(portal.MXID) == 0 { if len(portal.MXID) == 0 {
if msg.fake == nil && msg.undecryptable == nil && (msg.evt == nil || !containsSupportedMessage(msg.evt.Message)) { if msg.fake == nil && msg.undecryptable == nil && (msg.evt == nil || !containsSupportedMessage(msg.evt.Message)) {

View file

@ -43,15 +43,15 @@ import (
) )
type ProvisioningAPI struct { type ProvisioningAPI struct {
bridge *Bridge bridge *WABridge
log log.Logger log log.Logger
} }
func (prov *ProvisioningAPI) Init() { func (prov *ProvisioningAPI) Init() {
prov.log = prov.bridge.Log.Sub("Provisioning") prov.log = prov.bridge.Log.Sub("Provisioning")
prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.AppService.Provisioning.Prefix) prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.Bridge.Provisioning.Prefix)
r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.AppService.Provisioning.Prefix).Subrouter() r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.Bridge.Provisioning.Prefix).Subrouter()
r.Use(prov.AuthMiddleware) r.Use(prov.AuthMiddleware)
r.HandleFunc("/v1/ping", prov.Ping).Methods(http.MethodGet) r.HandleFunc("/v1/ping", prov.Ping).Methods(http.MethodGet)
r.HandleFunc("/v1/login", prov.Login).Methods(http.MethodGet) r.HandleFunc("/v1/login", prov.Login).Methods(http.MethodGet)
@ -109,7 +109,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
} else if strings.HasPrefix(auth, "Bearer ") { } else if strings.HasPrefix(auth, "Bearer ") {
auth = auth[len("Bearer "):] auth = auth[len("Bearer "):]
} }
if auth != prov.bridge.Config.AppService.Provisioning.SharedSecret { if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret {
jsonResponse(w, http.StatusForbidden, map[string]interface{}{ jsonResponse(w, http.StatusForbidden, map[string]interface{}{
"error": "Invalid auth token", "error": "Invalid auth token",
"errcode": "M_FORBIDDEN", "errcode": "M_FORBIDDEN",

View file

@ -39,11 +39,11 @@ import (
var userIDRegex *regexp.Regexp var userIDRegex *regexp.Regexp
func (bridge *Bridge) ParsePuppetMXID(mxid id.UserID) (jid types.JID, ok bool) { func (br *WABridge) ParsePuppetMXID(mxid id.UserID) (jid types.JID, ok bool) {
if userIDRegex == nil { if userIDRegex == nil {
userIDRegex = regexp.MustCompile(fmt.Sprintf("^@%s:%s$", userIDRegex = regexp.MustCompile(fmt.Sprintf("^@%s:%s$",
bridge.Config.Bridge.FormatUsername("([0-9]+)"), br.Config.Bridge.FormatUsername("([0-9]+)"),
bridge.Config.Homeserver.Domain)) br.Config.Homeserver.Domain))
} }
match := userIDRegex.FindStringSubmatch(string(mxid)) match := userIDRegex.FindStringSubmatch(string(mxid))
if len(match) == 2 { if len(match) == 2 {
@ -53,79 +53,79 @@ func (bridge *Bridge) ParsePuppetMXID(mxid id.UserID) (jid types.JID, ok bool) {
return return
} }
func (bridge *Bridge) GetPuppetByMXID(mxid id.UserID) *Puppet { func (br *WABridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
jid, ok := bridge.ParsePuppetMXID(mxid) jid, ok := br.ParsePuppetMXID(mxid)
if !ok { if !ok {
return nil return nil
} }
return bridge.GetPuppetByJID(jid) return br.GetPuppetByJID(jid)
} }
func (bridge *Bridge) GetPuppetByJID(jid types.JID) *Puppet { func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet {
jid = jid.ToNonAD() jid = jid.ToNonAD()
if jid.Server == types.LegacyUserServer { if jid.Server == types.LegacyUserServer {
jid.Server = types.DefaultUserServer jid.Server = types.DefaultUserServer
} else if jid.Server != types.DefaultUserServer { } else if jid.Server != types.DefaultUserServer {
return nil return nil
} }
bridge.puppetsLock.Lock() br.puppetsLock.Lock()
defer bridge.puppetsLock.Unlock() defer br.puppetsLock.Unlock()
puppet, ok := bridge.puppets[jid] puppet, ok := br.puppets[jid]
if !ok { if !ok {
dbPuppet := bridge.DB.Puppet.Get(jid) dbPuppet := br.DB.Puppet.Get(jid)
if dbPuppet == nil { if dbPuppet == nil {
dbPuppet = bridge.DB.Puppet.New() dbPuppet = br.DB.Puppet.New()
dbPuppet.JID = jid dbPuppet.JID = jid
dbPuppet.Insert() dbPuppet.Insert()
} }
puppet = bridge.NewPuppet(dbPuppet) puppet = br.NewPuppet(dbPuppet)
bridge.puppets[puppet.JID] = puppet br.puppets[puppet.JID] = puppet
if len(puppet.CustomMXID) > 0 { if len(puppet.CustomMXID) > 0 {
bridge.puppetsByCustomMXID[puppet.CustomMXID] = puppet br.puppetsByCustomMXID[puppet.CustomMXID] = puppet
} }
} }
return puppet return puppet
} }
func (bridge *Bridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet { func (br *WABridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet {
bridge.puppetsLock.Lock() br.puppetsLock.Lock()
defer bridge.puppetsLock.Unlock() defer br.puppetsLock.Unlock()
puppet, ok := bridge.puppetsByCustomMXID[mxid] puppet, ok := br.puppetsByCustomMXID[mxid]
if !ok { if !ok {
dbPuppet := bridge.DB.Puppet.GetByCustomMXID(mxid) dbPuppet := br.DB.Puppet.GetByCustomMXID(mxid)
if dbPuppet == nil { if dbPuppet == nil {
return nil return nil
} }
puppet = bridge.NewPuppet(dbPuppet) puppet = br.NewPuppet(dbPuppet)
bridge.puppets[puppet.JID] = puppet br.puppets[puppet.JID] = puppet
bridge.puppetsByCustomMXID[puppet.CustomMXID] = puppet br.puppetsByCustomMXID[puppet.CustomMXID] = puppet
} }
return puppet return puppet
} }
func (bridge *Bridge) GetAllPuppetsWithCustomMXID() []*Puppet { func (br *WABridge) GetAllPuppetsWithCustomMXID() []*Puppet {
return bridge.dbPuppetsToPuppets(bridge.DB.Puppet.GetAllWithCustomMXID()) return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID())
} }
func (bridge *Bridge) GetAllPuppets() []*Puppet { func (br *WABridge) GetAllPuppets() []*Puppet {
return bridge.dbPuppetsToPuppets(bridge.DB.Puppet.GetAll()) return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll())
} }
func (bridge *Bridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet { func (br *WABridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet {
bridge.puppetsLock.Lock() br.puppetsLock.Lock()
defer bridge.puppetsLock.Unlock() defer br.puppetsLock.Unlock()
output := make([]*Puppet, len(dbPuppets)) output := make([]*Puppet, len(dbPuppets))
for index, dbPuppet := range dbPuppets { for index, dbPuppet := range dbPuppets {
if dbPuppet == nil { if dbPuppet == nil {
continue continue
} }
puppet, ok := bridge.puppets[dbPuppet.JID] puppet, ok := br.puppets[dbPuppet.JID]
if !ok { if !ok {
puppet = bridge.NewPuppet(dbPuppet) puppet = br.NewPuppet(dbPuppet)
bridge.puppets[dbPuppet.JID] = puppet br.puppets[dbPuppet.JID] = puppet
if len(dbPuppet.CustomMXID) > 0 { if len(dbPuppet.CustomMXID) > 0 {
bridge.puppetsByCustomMXID[dbPuppet.CustomMXID] = puppet br.puppetsByCustomMXID[dbPuppet.CustomMXID] = puppet
} }
} }
output[index] = puppet output[index] = puppet
@ -133,26 +133,26 @@ func (bridge *Bridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet
return output return output
} }
func (bridge *Bridge) FormatPuppetMXID(jid types.JID) id.UserID { func (br *WABridge) FormatPuppetMXID(jid types.JID) id.UserID {
return id.NewUserID( return id.NewUserID(
bridge.Config.Bridge.FormatUsername(jid.User), br.Config.Bridge.FormatUsername(jid.User),
bridge.Config.Homeserver.Domain) br.Config.Homeserver.Domain)
} }
func (bridge *Bridge) NewPuppet(dbPuppet *database.Puppet) *Puppet { func (br *WABridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
return &Puppet{ return &Puppet{
Puppet: dbPuppet, Puppet: dbPuppet,
bridge: bridge, bridge: br,
log: bridge.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)), log: br.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
MXID: bridge.FormatPuppetMXID(dbPuppet.JID), MXID: br.FormatPuppetMXID(dbPuppet.JID),
} }
} }
type Puppet struct { type Puppet struct {
*database.Puppet *database.Puppet
bridge *Bridge bridge *WABridge
log log.Logger log log.Logger
typingIn id.RoomID typingIn id.RoomID

75
user.go
View file

@ -35,6 +35,7 @@ import (
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format" "maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
@ -56,7 +57,7 @@ type User struct {
Client *whatsmeow.Client Client *whatsmeow.Client
Session *store.Device Session *store.Device
bridge *Bridge bridge *WABridge
log log.Logger log log.Logger
Admin bool Admin bool
@ -84,38 +85,46 @@ type User struct {
BackfillQueue *BackfillQueue BackfillQueue *BackfillQueue
} }
func (bridge *Bridge) getUserByMXID(userID id.UserID, onlyIfExists bool) *User { func (br *WABridge) getUserByMXID(userID id.UserID, onlyIfExists bool) *User {
_, isPuppet := bridge.ParsePuppetMXID(userID) _, isPuppet := br.ParsePuppetMXID(userID)
if isPuppet || userID == bridge.Bot.UserID { if isPuppet || userID == br.Bot.UserID {
return nil return nil
} }
bridge.usersLock.Lock() br.usersLock.Lock()
defer bridge.usersLock.Unlock() defer br.usersLock.Unlock()
user, ok := bridge.usersByMXID[userID] user, ok := br.usersByMXID[userID]
if !ok { if !ok {
userIDPtr := &userID userIDPtr := &userID
if onlyIfExists { if onlyIfExists {
userIDPtr = nil userIDPtr = nil
} }
return bridge.loadDBUser(bridge.DB.User.GetByMXID(userID), userIDPtr) return br.loadDBUser(br.DB.User.GetByMXID(userID), userIDPtr)
} }
return user return user
} }
func (bridge *Bridge) GetUserByMXID(userID id.UserID) *User { func (br *WABridge) GetUserByMXID(userID id.UserID) *User {
return bridge.getUserByMXID(userID, false) return br.getUserByMXID(userID, false)
} }
func (bridge *Bridge) GetUserByMXIDIfExists(userID id.UserID) *User { func (br *WABridge) GetIUserByMXID(userID id.UserID) bridge.User {
return bridge.getUserByMXID(userID, true) return br.getUserByMXID(userID, false)
} }
func (bridge *Bridge) GetUserByJID(jid types.JID) *User { func (user *User) IsAdmin() bool {
bridge.usersLock.Lock() return user.Admin
defer bridge.usersLock.Unlock() }
user, ok := bridge.usersByUsername[jid.User]
func (br *WABridge) GetUserByMXIDIfExists(userID id.UserID) *User {
return br.getUserByMXID(userID, true)
}
func (br *WABridge) GetUserByJID(jid types.JID) *User {
br.usersLock.Lock()
defer br.usersLock.Unlock()
user, ok := br.usersByUsername[jid.User]
if !ok { if !ok {
return bridge.loadDBUser(bridge.DB.User.GetByUsername(jid.User), nil) return br.loadDBUser(br.DB.User.GetByUsername(jid.User), nil)
} }
return user return user
} }
@ -137,35 +146,35 @@ func (user *User) removeFromJIDMap(state BridgeState) {
user.sendBridgeState(state) user.sendBridgeState(state)
} }
func (bridge *Bridge) GetAllUsers() []*User { func (br *WABridge) GetAllUsers() []*User {
bridge.usersLock.Lock() br.usersLock.Lock()
defer bridge.usersLock.Unlock() defer br.usersLock.Unlock()
dbUsers := bridge.DB.User.GetAll() dbUsers := br.DB.User.GetAll()
output := make([]*User, len(dbUsers)) output := make([]*User, len(dbUsers))
for index, dbUser := range dbUsers { for index, dbUser := range dbUsers {
user, ok := bridge.usersByMXID[dbUser.MXID] user, ok := br.usersByMXID[dbUser.MXID]
if !ok { if !ok {
user = bridge.loadDBUser(dbUser, nil) user = br.loadDBUser(dbUser, nil)
} }
output[index] = user output[index] = user
} }
return output return output
} }
func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User { func (br *WABridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
if dbUser == nil { if dbUser == nil {
if mxid == nil { if mxid == nil {
return nil return nil
} }
dbUser = bridge.DB.User.New() dbUser = br.DB.User.New()
dbUser.MXID = *mxid dbUser.MXID = *mxid
dbUser.Insert() dbUser.Insert()
} }
user := bridge.NewUser(dbUser) user := br.NewUser(dbUser)
bridge.usersByMXID[user.MXID] = user br.usersByMXID[user.MXID] = user
if !user.JID.IsEmpty() { if !user.JID.IsEmpty() {
var err error var err error
user.Session, err = bridge.WAContainer.GetDevice(user.JID) user.Session, err = br.WAContainer.GetDevice(user.JID)
if err != nil { if err != nil {
user.log.Errorfln("Failed to load user's whatsapp session: %v", err) user.log.Errorfln("Failed to load user's whatsapp session: %v", err)
} else if user.Session == nil { } else if user.Session == nil {
@ -174,20 +183,20 @@ func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
user.Update() user.Update()
} else { } else {
user.Session.Log = &waLogger{user.log.Sub("Session")} user.Session.Log = &waLogger{user.log.Sub("Session")}
bridge.usersByUsername[user.JID.User] = user br.usersByUsername[user.JID.User] = user
} }
} }
if len(user.ManagementRoom) > 0 { if len(user.ManagementRoom) > 0 {
bridge.managementRooms[user.ManagementRoom] = user br.managementRooms[user.ManagementRoom] = user
} }
return user return user
} }
func (bridge *Bridge) NewUser(dbUser *database.User) *User { func (br *WABridge) NewUser(dbUser *database.User) *User {
user := &User{ user := &User{
User: dbUser, User: dbUser,
bridge: bridge, bridge: br,
log: bridge.Log.Sub("User").Sub(string(dbUser.MXID)), log: br.Log.Sub("User").Sub(string(dbUser.MXID)),
historySyncs: make(chan *events.HistorySync, 32), historySyncs: make(chan *events.HistorySync, 32),
lastPresence: types.PresenceUnavailable, lastPresence: types.PresenceUnavailable,