diff --git a/bridgestate.go b/bridgestate.go index 492f4f8..2b66a8e 100644 --- a/bridgestate.go +++ b/bridgestate.go @@ -114,18 +114,18 @@ func (pong *BridgeState) shouldDeduplicate(newPong *BridgeState) bool { return pong.Timestamp+int64(pong.TTL/5) > time.Now().Unix() } -func (bridge *Bridge) sendBridgeState(ctx context.Context, state *BridgeState) error { +func (br *WABridge) sendBridgeState(ctx context.Context, state *BridgeState) error { var body bytes.Buffer if err := json.NewEncoder(&body).Encode(&state); err != nil { return fmt.Errorf("failed to encode bridge state JSON: %w", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, bridge.Config.Homeserver.StatusEndpoint, &body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, br.Config.Homeserver.StatusEndpoint, &body) if err != nil { return fmt.Errorf("failed to prepare request: %w", err) } - req.Header.Set("Authorization", "Bearer "+bridge.Config.AppService.ASToken) + req.Header.Set("Authorization", "Bearer "+br.Config.AppService.ASToken) req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) @@ -143,17 +143,17 @@ func (bridge *Bridge) sendBridgeState(ctx context.Context, state *BridgeState) e return nil } -func (bridge *Bridge) sendGlobalBridgeState(state BridgeState) { - if len(bridge.Config.Homeserver.StatusEndpoint) == 0 { +func (br *WABridge) sendGlobalBridgeState(state BridgeState) { + if len(br.Config.Homeserver.StatusEndpoint) == 0 { return } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - if err := bridge.sendBridgeState(ctx, &state); err != nil { - bridge.Log.Warnln("Failed to update global bridge state:", err) + if err := br.sendBridgeState(ctx, &state); err != nil { + br.Log.Warnln("Failed to update global bridge state:", err) } else { - bridge.Log.Debugfln("Sent new global bridge state %+v", state) + br.Log.Debugfln("Sent new global bridge state %+v", state) } } diff --git a/commands.go b/commands.go index c19a361..13a27f8 100644 --- a/commands.go +++ b/commands.go @@ -39,6 +39,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/id" @@ -47,12 +48,12 @@ import ( ) type CommandHandler struct { - bridge *Bridge + bridge *WABridge log maulogger.Logger } // NewCommandHandler creates a CommandHandler -func NewCommandHandler(bridge *Bridge) *CommandHandler { +func NewCommandHandler(bridge *WABridge) *CommandHandler { return &CommandHandler{ bridge: bridge, log: bridge.Log.Sub("Command handler"), @@ -62,7 +63,7 @@ func NewCommandHandler(bridge *Bridge) *CommandHandler { // CommandEvent stores all data which might be used to handle commands type CommandEvent struct { Bot *appservice.IntentAPI - Bridge *Bridge + Bridge *WABridge Portal *Portal Handler *CommandHandler RoomID id.RoomID @@ -251,13 +252,7 @@ func (handler *CommandHandler) CommandDevTest(_ *CommandEvent) { const cmdVersionHelp = `version - View the bridge version` func (handler *CommandHandler) CommandVersion(ce *CommandEvent) { - linkifiedVersion := fmt.Sprintf("v%s", Version) - if Tag == Version { - linkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", Version, URL, Tag) - } else if len(Commit) > 8 { - linkifiedVersion = strings.Replace(linkifiedVersion, Commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", Commit[:8], URL, Commit), 1) - } - ce.Reply(fmt.Sprintf("[%s](%s) %s (%s)", Name, URL, linkifiedVersion, BuildTime)) + ce.Reply(fmt.Sprintf("[%s](%s) %s (%s)", ce.Bridge.Name, ce.Bridge.URL, ce.Bridge.LinkifiedVersion, BuildTime)) } const cmdInviteLinkHelp = `invite-link [--reset] - Get an invite link to the current group chat, optionally regenerating the link and revoking the old link.` @@ -331,7 +326,7 @@ func (handler *CommandHandler) CommandJoin(ce *CommandEvent) { ce.Reply("Successfully joined group `%s`, the portal should be created momentarily", jid) } -func tryDecryptEvent(crypto Crypto, evt *event.Event) (json.RawMessage, error) { +func tryDecryptEvent(crypto bridge.Crypto, evt *event.Event) (json.RawMessage, error) { var data json.RawMessage if evt.Type != event.EventEncrypted { data = evt.Content.VeryRaw @@ -903,7 +898,7 @@ func matchesQuery(str string, query string) bool { return strings.Contains(strings.ToLower(str), query) } -func formatContacts(bridge *Bridge, input map[types.JID]types.ContactInfo, query string) (result []string) { +func formatContacts(bridge *WABridge, input map[types.JID]types.ContactInfo, query string) (result []string) { hasQuery := len(query) > 0 for jid, contact := range input { if len(contact.FullName) == 0 { diff --git a/config/bridge.go b/config/bridge.go index 6127124..24813d5 100644 --- a/config/bridge.go +++ b/config/bridge.go @@ -24,6 +24,7 @@ import ( "go.mau.fi/whatsmeow/types" + "maunium.net/go/mautrix/bridge/bridgeconfig" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -118,23 +119,23 @@ type BridgeConfig struct { AdditionalHelp string `yaml:"additional_help"` } `yaml:"management_room_text"` - Encryption struct { - Allow bool `yaml:"allow"` - Default bool `yaml:"default"` + Encryption bridgeconfig.EncryptionConfig `yaml:"encryption"` - KeySharing struct { - Allow bool `yaml:"allow"` - RequireCrossSigning bool `yaml:"require_cross_signing"` - RequireVerification bool `yaml:"require_verification"` - } `yaml:"key_sharing"` - } `yaml:"encryption"` + Provisioning struct { + Prefix string `yaml:"prefix"` + SharedSecret string `yaml:"shared_secret"` + } `yaml:"provisioning"` Permissions PermissionConfig `yaml:"permissions"` Relay RelaybotConfig `yaml:"relay"` - usernameTemplate *template.Template `yaml:"-"` - displaynameTemplate *template.Template `yaml:"-"` + ParsedUsernameTemplate *template.Template `yaml:"-"` + displaynameTemplate *template.Template `yaml:"-"` +} + +func (bc BridgeConfig) GetEncryptionConfig() bridgeconfig.EncryptionConfig { + return bc.Encryption } type umBridgeConfig BridgeConfig @@ -145,7 +146,7 @@ func (bc *BridgeConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { return err } - bc.usernameTemplate, err = template.New("username").Parse(bc.UsernameTemplate) + bc.ParsedUsernameTemplate, err = template.New("username").Parse(bc.UsernameTemplate) if err != nil { return err } else if !strings.Contains(bc.FormatUsername("1234567890"), "1234567890") { @@ -206,7 +207,7 @@ func (bc BridgeConfig) FormatDisplayname(jid types.JID, contact types.ContactInf func (bc BridgeConfig) FormatUsername(username string) string { var buf strings.Builder - _ = bc.usernameTemplate.Execute(&buf, username) + _ = bc.ParsedUsernameTemplate.Execute(&buf, username) return buf.String() } diff --git a/config/config.go b/config/config.go index a906682..2379f2e 100644 --- a/config/config.go +++ b/config/config.go @@ -17,52 +17,12 @@ package config import ( - "fmt" - - "gopkg.in/yaml.v3" - - "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridge/bridgeconfig" "maunium.net/go/mautrix/id" ) -var ExampleConfig string - type Config struct { - Homeserver struct { - Address string `yaml:"address"` - Domain string `yaml:"domain"` - Asmux bool `yaml:"asmux"` - StatusEndpoint string `yaml:"status_endpoint"` - MessageSendCheckpointEndpoint string `yaml:"message_send_checkpoint_endpoint"` - AsyncMedia bool `yaml:"async_media"` - } `yaml:"homeserver"` - - AppService struct { - Address string `yaml:"address"` - Hostname string `yaml:"hostname"` - Port uint16 `yaml:"port"` - - Database DatabaseConfig `yaml:"database"` - - Provisioning struct { - Prefix string `yaml:"prefix"` - SharedSecret string `yaml:"shared_secret"` - } `yaml:"provisioning"` - - ID string `yaml:"id"` - Bot struct { - Username string `yaml:"username"` - Displayname string `yaml:"displayname"` - Avatar string `yaml:"avatar"` - - ParsedAvatar id.ContentURI `yaml:"-"` - } `yaml:"bot"` - - EphemeralEvents bool `yaml:"ephemeral_events"` - - ASToken string `yaml:"as_token"` - HSToken string `yaml:"hs_token"` - } `yaml:"appservice"` + *bridgeconfig.BaseConfig `yaml:",inline"` SegmentKey string `yaml:"segment_key"` @@ -77,8 +37,6 @@ type Config struct { } `yaml:"whatsapp"` Bridge BridgeConfig `yaml:"bridge"` - - Logging appservice.LogConfig `yaml:"logging"` } func (config *Config) CanAutoDoublePuppet(userID id.UserID) bool { @@ -98,44 +56,3 @@ func (config *Config) CanDoublePuppetBackfill(userID id.UserID) bool { } return true } - -func Load(data []byte, upgraded bool) (*Config, error) { - var config = &Config{} - if !upgraded { - // Fallback: if config upgrading failed, load example config for base values - err := yaml.Unmarshal([]byte(ExampleConfig), config) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal example config: %w", err) - } - } - err := yaml.Unmarshal(data, config) - if err != nil { - return nil, err - } - - return config, err -} - -func (config *Config) MakeAppService() (*appservice.AppService, error) { - as := appservice.Create() - as.HomeserverDomain = config.Homeserver.Domain - as.HomeserverURL = config.Homeserver.Address - as.Host.Hostname = config.AppService.Hostname - as.Host.Port = config.AppService.Port - as.MessageSendCheckpointEndpoint = config.Homeserver.MessageSendCheckpointEndpoint - as.DefaultHTTPRetries = 4 - var err error - as.Registration, err = config.GetRegistration() - return as, err -} - -type DatabaseConfig struct { - Type string `yaml:"type"` - URI string `yaml:"uri"` - - MaxOpenConns int `yaml:"max_open_conns"` - MaxIdleConns int `yaml:"max_idle_conns"` - - ConnMaxIdleTime string `yaml:"conn_max_idle_time"` - ConnMaxLifetime string `yaml:"conn_max_lifetime"` -} diff --git a/config/registration.go b/config/registration.go deleted file mode 100644 index 8fdcef0..0000000 --- a/config/registration.go +++ /dev/null @@ -1,82 +0,0 @@ -// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2019 Tulir Asokan -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package config - -import ( - "fmt" - "regexp" - "strings" - - "maunium.net/go/mautrix/appservice" -) - -func (config *Config) NewRegistration() (*appservice.Registration, error) { - registration := appservice.CreateRegistration() - - err := config.copyToRegistration(registration) - if err != nil { - return nil, err - } - - config.AppService.ASToken = registration.AppToken - config.AppService.HSToken = registration.ServerToken - - // Workaround for https://github.com/matrix-org/synapse/pull/5758 - registration.SenderLocalpart = appservice.RandomString(32) - botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$", - regexp.QuoteMeta(config.AppService.Bot.Username), - regexp.QuoteMeta(config.Homeserver.Domain))) - registration.Namespaces.RegisterUserIDs(botRegex, true) - - return registration, nil -} - -func (config *Config) GetRegistration() (*appservice.Registration, error) { - registration := appservice.CreateRegistration() - - err := config.copyToRegistration(registration) - if err != nil { - return nil, err - } - - registration.AppToken = config.AppService.ASToken - registration.ServerToken = config.AppService.HSToken - return registration, nil -} - -func (config *Config) copyToRegistration(registration *appservice.Registration) error { - registration.ID = config.AppService.ID - registration.URL = config.AppService.Address - falseVal := false - registration.RateLimited = &falseVal - registration.SenderLocalpart = config.AppService.Bot.Username - registration.EphemeralEvents = config.AppService.EphemeralEvents - - usernamePlaceholder := appservice.RandomString(16) - usernameTemplate := fmt.Sprintf("@%s:%s", - config.Bridge.FormatUsername(usernamePlaceholder), - config.Homeserver.Domain) - usernameTemplate = regexp.QuoteMeta(usernameTemplate) - usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, "[0-9]+", 1) - usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate) - userIDRegex, err := regexp.Compile(usernameTemplate) - if err != nil { - return err - } - registration.Namespaces.RegisterUserIDs(userIDRegex, true) - return nil -} diff --git a/config/upgrade.go b/config/upgrade.go index 6c6d681..49b30fc 100644 --- a/config/upgrade.go +++ b/config/upgrade.go @@ -20,50 +20,12 @@ import ( "strings" "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridge/bridgeconfig" up "maunium.net/go/mautrix/util/configupgrade" ) -type waUpgrader struct{} - -func (wau waUpgrader) GetBase() string { - return ExampleConfig -} - -func (wau waUpgrader) DoUpgrade(helper *up.Helper) { - helper.Copy(up.Str, "homeserver", "address") - helper.Copy(up.Str, "homeserver", "domain") - helper.Copy(up.Bool, "homeserver", "asmux") - helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint") - helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint") - helper.Copy(up.Bool, "homeserver", "async_media") - - helper.Copy(up.Str, "appservice", "address") - helper.Copy(up.Str, "appservice", "hostname") - helper.Copy(up.Int, "appservice", "port") - helper.Copy(up.Str, "appservice", "database", "type") - helper.Copy(up.Str, "appservice", "database", "uri") - helper.Copy(up.Int, "appservice", "database", "max_open_conns") - helper.Copy(up.Int, "appservice", "database", "max_idle_conns") - helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_idle_time") - helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_lifetime") - if prefix, ok := helper.Get(up.Str, "appservice", "provisioning", "prefix"); ok && strings.HasSuffix(prefix, "/v1") { - helper.Set(up.Str, strings.TrimSuffix(prefix, "/v1"), "appservice", "provisioning", "prefix") - } else { - helper.Copy(up.Str, "appservice", "provisioning", "prefix") - } - if secret, ok := helper.Get(up.Str, "appservice", "provisioning", "shared_secret"); !ok || secret == "generate" { - sharedSecret := appservice.RandomString(64) - helper.Set(up.Str, sharedSecret, "appservice", "provisioning", "shared_secret") - } else { - helper.Copy(up.Str, "appservice", "provisioning", "shared_secret") - } - helper.Copy(up.Str, "appservice", "id") - helper.Copy(up.Str, "appservice", "bot", "username") - helper.Copy(up.Str, "appservice", "bot", "displayname") - helper.Copy(up.Str, "appservice", "bot", "avatar") - helper.Copy(up.Bool, "appservice", "ephemeral_events") - helper.Copy(up.Str, "appservice", "as_token") - helper.Copy(up.Str, "appservice", "hs_token") +func DoUpgrade(helper *up.Helper) { + bridgeconfig.Upgrader.DoUpgrade(helper) helper.Copy(up.Str|up.Null, "segment_key") @@ -134,46 +96,41 @@ func (wau waUpgrader) DoUpgrade(helper *up.Helper) { helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "allow") helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "require_cross_signing") helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "require_verification") + if prefix, ok := helper.Get(up.Str, "appservice", "provisioning", "prefix"); ok { + helper.Set(up.Str, strings.TrimSuffix(prefix, "/v1"), "bridge", "provisioning", "prefix") + } else { + helper.Copy(up.Str, "bridge", "provisioning", "prefix") + } + if secret, ok := helper.Get(up.Str, "appservice", "provisioning", "shared_secret"); ok && secret != "generate" { + helper.Set(up.Str, secret, "bridge", "provisioning", "shared_secret") + } else if secret, ok = helper.Get(up.Str, "bridge", "provisioning", "shared_secret"); !ok || secret == "generate" { + sharedSecret := appservice.RandomString(64) + helper.Set(up.Str, sharedSecret, "bridge", "provisioning", "shared_secret") + } else { + helper.Copy(up.Str, "bridge", "provisioning", "shared_secret") + } helper.Copy(up.Map, "bridge", "permissions") helper.Copy(up.Bool, "bridge", "relay", "enabled") helper.Copy(up.Bool, "bridge", "relay", "admin_only") helper.Copy(up.Map, "bridge", "relay", "message_formats") - - helper.Copy(up.Str, "logging", "directory") - helper.Copy(up.Str|up.Null, "logging", "file_name_format") - helper.Copy(up.Str|up.Timestamp, "logging", "file_date_format") - helper.Copy(up.Int, "logging", "file_mode") - helper.Copy(up.Str|up.Timestamp, "logging", "timestamp_format") - helper.Copy(up.Str, "logging", "print_level") } -func (wau waUpgrader) SpacedBlocks() [][]string { - return [][]string{ - {"homeserver", "asmux"}, - {"appservice"}, - {"appservice", "hostname"}, - {"appservice", "database"}, - {"appservice", "provisioning"}, - {"appservice", "id"}, - {"appservice", "as_token"}, - {"segment_key"}, - {"metrics"}, - {"whatsapp"}, - {"bridge"}, - {"bridge", "command_prefix"}, - {"bridge", "management_room_text"}, - {"bridge", "encryption"}, - {"bridge", "permissions"}, - {"bridge", "relay"}, - {"logging"}, - } -} - -func Mutate(path string, mutate func(helper *up.Helper)) error { - _, _, err := up.Do(path, true, waUpgrader{}, up.SimpleUpgrader(mutate)) - return err -} - -func Upgrade(path string, save bool) ([]byte, bool, error) { - return up.Do(path, save, waUpgrader{}) +var SpacedBlocks = [][]string{ + {"homeserver", "asmux"}, + {"appservice"}, + {"appservice", "hostname"}, + {"appservice", "database"}, + {"appservice", "id"}, + {"appservice", "as_token"}, + {"segment_key"}, + {"metrics"}, + {"whatsapp"}, + {"bridge"}, + {"bridge", "command_prefix"}, + {"bridge", "management_room_text"}, + {"bridge", "encryption"}, + {"bridge", "provisioning"}, + {"bridge", "permissions"}, + {"bridge", "relay"}, + {"logging"}, } diff --git a/crypto.go b/crypto.go deleted file mode 100644 index b43231c..0000000 --- a/crypto.go +++ /dev/null @@ -1,327 +0,0 @@ -// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2020 Tulir Asokan -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -//go:build cgo && !nocrypto - -package main - -import ( - "fmt" - "runtime/debug" - "time" - - "github.com/lib/pq" - - "maunium.net/go/maulogger/v2" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" - - "maunium.net/go/mautrix-whatsapp/database" -) - -var NoSessionFound = crypto.NoSessionFound - -var levelTrace = maulogger.Level{ - Name: "TRACE", - Severity: -10, - Color: -1, -} - -type CryptoHelper struct { - bridge *Bridge - client *mautrix.Client - mach *crypto.OlmMachine - store *database.SQLCryptoStore - log maulogger.Logger - baseLog maulogger.Logger -} - -func init() { - crypto.PostgresArrayWrapper = pq.Array -} - -func NewCryptoHelper(bridge *Bridge) Crypto { - if !bridge.Config.Bridge.Encryption.Allow { - bridge.Log.Debugln("Bridge built with end-to-bridge encryption, but disabled in config") - return nil - } - baseLog := bridge.Log.Sub("Crypto") - return &CryptoHelper{ - bridge: bridge, - log: baseLog.Sub("Helper"), - baseLog: baseLog, - } -} - -func (helper *CryptoHelper) Init() error { - helper.log.Debugln("Initializing end-to-bridge encryption...") - - helper.store = database.NewSQLCryptoStore(helper.bridge.DB, helper.bridge.AS.BotMXID(), - fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain)) - - var err error - helper.client, err = helper.loginBot() - if err != nil { - return err - } - - helper.log.Debugln("Logged in as bridge bot with device ID", helper.client.DeviceID) - logger := &cryptoLogger{helper.baseLog} - stateStore := &cryptoStateStore{helper.bridge} - helper.mach = crypto.NewOlmMachine(helper.client, logger, helper.store, stateStore) - helper.mach.AllowKeyShare = helper.allowKeyShare - - helper.client.Syncer = &cryptoSyncer{helper.mach} - helper.client.Store = &cryptoClientStore{helper.store} - - return helper.mach.Load() -} - -func (helper *CryptoHelper) allowKeyShare(device *crypto.DeviceIdentity, info event.RequestedKeyInfo) *crypto.KeyShareRejection { - cfg := helper.bridge.Config.Bridge.Encryption.KeySharing - if !cfg.Allow { - return &crypto.KeyShareRejectNoResponse - } else if device.Trust == crypto.TrustStateBlacklisted { - return &crypto.KeyShareRejectBlacklisted - } else if device.Trust == crypto.TrustStateVerified || !cfg.RequireVerification { - portal := helper.bridge.GetPortalByMXID(info.RoomID) - if portal == nil { - helper.log.Debugfln("Rejecting key request for %s from %s/%s: room is not a portal", info.SessionID, device.UserID, device.DeviceID) - return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnavailable, Reason: "Requested room is not a portal room"} - } - user := helper.bridge.GetUserByMXID(device.UserID) - // FIXME reimplement IsInPortal - if !user.Admin /*&& !user.IsInPortal(portal.Key)*/ { - helper.log.Debugfln("Rejecting key request for %s from %s/%s: user is not in portal", info.SessionID, device.UserID, device.DeviceID) - return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "You're not in that portal"} - } - helper.log.Debugfln("Accepting key request for %s from %s/%s", info.SessionID, device.UserID, device.DeviceID) - return nil - } else { - return &crypto.KeyShareRejectUnverified - } -} - -func (helper *CryptoHelper) loginBot() (*mautrix.Client, error) { - deviceID := helper.store.FindDeviceID() - if len(deviceID) > 0 { - helper.log.Debugln("Found existing device ID for bot in database:", deviceID) - } - client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, "", "") - if err != nil { - return nil, fmt.Errorf("failed to initialize client: %w", err) - } - client.Logger = helper.baseLog.Sub("Bot") - client.Client = helper.bridge.AS.HTTPClient - client.DefaultHTTPRetries = helper.bridge.AS.DefaultHTTPRetries - flows, err := client.GetLoginFlows() - if err != nil { - return nil, fmt.Errorf("failed to get supported login flows: %w", err) - } else if !flows.HasFlow(mautrix.AuthTypeAppservice) { - return nil, fmt.Errorf("homeserver does not support appservice login") - } - // We set the API token to the AS token here to authenticate the appservice login - // It'll get overridden after the login - client.AccessToken = helper.bridge.AS.Registration.AppToken - resp, err := client.Login(&mautrix.ReqLogin{ - Type: mautrix.AuthTypeAppservice, - Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(helper.bridge.AS.BotMXID())}, - DeviceID: deviceID, - InitialDeviceDisplayName: "WhatsApp Bridge", - StoreCredentials: true, - }) - if err != nil { - return nil, fmt.Errorf("failed to log in as bridge bot: %w", err) - } - helper.store.DeviceID = resp.DeviceID - return client, nil -} - -func (helper *CryptoHelper) Start() { - helper.log.Debugln("Starting syncer for receiving to-device messages") - err := helper.client.Sync() - if err != nil { - helper.log.Errorln("Fatal error syncing:", err) - } else { - helper.log.Infoln("Bridge bot to-device syncer stopped without error") - } -} - -func (helper *CryptoHelper) Stop() { - helper.log.Debugln("CryptoHelper.Stop() called, stopping bridge bot sync") - helper.client.StopSync() -} - -func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) { - return helper.mach.DecryptMegolmEvent(evt) -} - -func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content event.Content) (*event.EncryptedEventContent, error) { - encrypted, err := helper.mach.EncryptMegolmEvent(roomID, evtType, &content) - if err != nil { - if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession { - return nil, err - } - helper.log.Debugfln("Got %v while encrypting event for %s, sharing group session and trying again...", err, roomID) - users, err := helper.store.GetRoomMembers(roomID) - if err != nil { - return nil, fmt.Errorf("failed to get room member list: %w", err) - } - err = helper.mach.ShareGroupSession(roomID, users) - if err != nil { - return nil, fmt.Errorf("failed to share group session: %w", err) - } - encrypted, err = helper.mach.EncryptMegolmEvent(roomID, evtType, &content) - if err != nil { - return nil, fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err) - } - } - return encrypted, nil -} - -func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { - return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout) -} - -func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { - err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}}) - if err != nil { - helper.log.Warnfln("Failed to send key request to %s/%s for %s in %s: %v", userID, deviceID, sessionID, roomID, err) - } else { - helper.log.Debugfln("Sent key request to %s/%s for %s in %s", userID, deviceID, sessionID, roomID) - } -} - -func (helper *CryptoHelper) ResetSession(roomID id.RoomID) { - err := helper.mach.CryptoStore.RemoveOutboundGroupSession(roomID) - if err != nil { - helper.log.Debugfln("Error manually removing outbound group session in %s: %v", roomID, err) - } -} - -func (helper *CryptoHelper) HandleMemberEvent(evt *event.Event) { - helper.mach.HandleMemberEvent(evt) -} - -type cryptoSyncer struct { - *crypto.OlmMachine -} - -func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string) error { - done := make(chan struct{}) - go func() { - defer func() { - if err := recover(); err != nil { - syncer.Log.Error("Processing sync response (%s) panicked: %v\n%s", since, err, debug.Stack()) - } - done <- struct{}{} - }() - syncer.Log.Trace("Starting sync response handling (%s)", since) - syncer.ProcessSyncResponse(resp, since) - syncer.Log.Trace("Successfully handled sync response (%s)", since) - }() - select { - case <-done: - case <-time.After(30 * time.Second): - syncer.Log.Warn("Handling sync response (%s) is taking unusually long", since) - } - return nil -} - -func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) { - syncer.Log.Error("Error /syncing, waiting 10 seconds: %v", err) - return 10 * time.Second, nil -} - -func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { - everything := []event.Type{{Type: "*"}} - return &mautrix.Filter{ - Presence: mautrix.FilterPart{NotTypes: everything}, - AccountData: mautrix.FilterPart{NotTypes: everything}, - Room: mautrix.RoomFilter{ - IncludeLeave: false, - Ephemeral: mautrix.FilterPart{NotTypes: everything}, - AccountData: mautrix.FilterPart{NotTypes: everything}, - State: mautrix.FilterPart{NotTypes: everything}, - Timeline: mautrix.FilterPart{NotTypes: everything}, - }, - } -} - -type cryptoLogger struct { - int maulogger.Logger -} - -func (c *cryptoLogger) Error(message string, args ...interface{}) { - c.int.Errorfln(message, args...) -} - -func (c *cryptoLogger) Warn(message string, args ...interface{}) { - c.int.Warnfln(message, args...) -} - -func (c *cryptoLogger) Debug(message string, args ...interface{}) { - c.int.Debugfln(message, args...) -} - -func (c *cryptoLogger) Trace(message string, args ...interface{}) { - c.int.Logfln(levelTrace, message, args...) -} - -type cryptoClientStore struct { - int *database.SQLCryptoStore -} - -func (c cryptoClientStore) SaveFilterID(_ id.UserID, _ string) {} -func (c cryptoClientStore) LoadFilterID(_ id.UserID) string { return "" } -func (c cryptoClientStore) SaveRoom(_ *mautrix.Room) {} -func (c cryptoClientStore) LoadRoom(_ id.RoomID) *mautrix.Room { return nil } - -func (c cryptoClientStore) SaveNextBatch(_ id.UserID, nextBatchToken string) { - c.int.PutNextBatch(nextBatchToken) -} - -func (c cryptoClientStore) LoadNextBatch(_ id.UserID) string { - return c.int.GetNextBatch() -} - -var _ mautrix.Storer = (*cryptoClientStore)(nil) - -type cryptoStateStore struct { - bridge *Bridge -} - -var _ crypto.StateStore = (*cryptoStateStore)(nil) - -func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool { - portal := c.bridge.GetPortalByMXID(id) - if portal != nil { - return portal.Encrypted - } - return false -} - -func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID { - return c.bridge.StateStore.FindSharedRooms(id) -} - -func (c *cryptoStateStore) GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent { - // TODO implement - return nil -} diff --git a/custompuppet.go b/custompuppet.go index 2284ea9..d3db3e7 100644 --- a/custompuppet.go +++ b/custompuppet.go @@ -75,8 +75,8 @@ func (puppet *Puppet) loginWithSharedSecret(mxid id.UserID) (string, error) { Type: mautrix.AuthTypePassword, Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(mxid)}, Password: hex.EncodeToString(mac.Sum(nil)), - DeviceID: "WhatsApp Bridge", - InitialDeviceDisplayName: "WhatsApp Bridge", + DeviceID: "WhatsApp bridge", + InitialDeviceDisplayName: "WhatsApp bridge", }) if err != nil { return "", err @@ -84,22 +84,22 @@ func (puppet *Puppet) loginWithSharedSecret(mxid id.UserID) (string, error) { return resp.AccessToken, nil } -func (bridge *Bridge) newDoublePuppetClient(mxid id.UserID, accessToken string) (*mautrix.Client, error) { +func (br *WABridge) newDoublePuppetClient(mxid id.UserID, accessToken string) (*mautrix.Client, error) { _, homeserver, err := mxid.Parse() if err != nil { return nil, err } - homeserverURL, found := bridge.Config.Bridge.DoublePuppetServerMap[homeserver] + homeserverURL, found := br.Config.Bridge.DoublePuppetServerMap[homeserver] if !found { - if homeserver == bridge.AS.HomeserverDomain { - homeserverURL = bridge.AS.HomeserverURL - } else if bridge.Config.Bridge.DoublePuppetAllowDiscovery { + if homeserver == br.AS.HomeserverDomain { + homeserverURL = br.AS.HomeserverURL + } else if br.Config.Bridge.DoublePuppetAllowDiscovery { resp, err := mautrix.DiscoverClientAPI(homeserver) if err != nil { return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err) } homeserverURL = resp.Homeserver.BaseURL - bridge.Log.Debugfln("Discovered URL %s for %s to enable double puppeting for %s", homeserverURL, homeserver, mxid) + br.Log.Debugfln("Discovered URL %s for %s to enable double puppeting for %s", homeserverURL, homeserver, mxid) } else { return nil, fmt.Errorf("double puppeting from %s is not allowed", homeserver) } @@ -108,9 +108,9 @@ func (bridge *Bridge) newDoublePuppetClient(mxid id.UserID, accessToken string) if err != nil { return nil, err } - client.Logger = bridge.AS.Log.Sub(mxid.String()) - client.Client = bridge.AS.HTTPClient - client.DefaultHTTPRetries = bridge.AS.DefaultHTTPRetries + client.Logger = br.AS.Log.Sub(mxid.String()) + client.Client = br.AS.HTTPClient + client.DefaultHTTPRetries = br.AS.DefaultHTTPRetries return client, nil } diff --git a/database/backfill.go b/database/backfill.go index 1a0979b..77a84ed 100644 --- a/database/backfill.go +++ b/database/backfill.go @@ -26,7 +26,9 @@ import ( "time" log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/util/dbutil" ) type BackfillType int @@ -165,7 +167,7 @@ func (b *Backfill) String() string { ) } -func (b *Backfill) Scan(row Scannable) *Backfill { +func (b *Backfill) Scan(row dbutil.Scannable) *Backfill { err := row.Scan(&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, &b.MaxBatchEvents, &b.MaxTotalEvents, &b.BatchDelay) if err != nil { if !errors.Is(err, sql.ErrNoRows) { @@ -256,7 +258,7 @@ type BackfillState struct { FirstExpectedTimestamp uint64 } -func (b *BackfillState) Scan(row Scannable) *BackfillState { +func (b *BackfillState) Scan(row dbutil.Scannable) *BackfillState { err := row.Scan(&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp) if err != nil { if !errors.Is(err, sql.ErrNoRows) { diff --git a/database/cryptostore.go b/database/cryptostore.go index ceca60d..fc221a7 100644 --- a/database/cryptostore.go +++ b/database/cryptostore.go @@ -1,18 +1,8 @@ -// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2020 Tulir Asokan +// Copyright (c) 2022 Tulir Asokan // -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. //go:build cgo && !nocrypto @@ -21,8 +11,6 @@ package database import ( "database/sql" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/id" ) @@ -37,11 +25,9 @@ var _ crypto.Store = (*SQLCryptoStore)(nil) func NewSQLCryptoStore(db *Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore { return &SQLCryptoStore{ - SQLCryptoStore: crypto.NewSQLCryptoStore(db.DB, db.dialect, "", "", - []byte("maunium.net/go/mautrix-whatsapp"), - &cryptoLogger{db.log.Sub("CryptoStore")}), - UserID: userID, - GhostIDFormat: ghostIDFormat, + SQLCryptoStore: crypto.NewSQLCryptoStore(db.Database, "", "", []byte("maunium.net/go/mautrix-whatsapp")), + UserID: userID, + GhostIDFormat: ghostIDFormat, } } @@ -76,30 +62,3 @@ func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.User } return } - -// TODO merge this with the one in the parent package -type cryptoLogger struct { - int log.Logger -} - -var levelTrace = log.Level{ - Name: "TRACE", - Severity: -10, - Color: -1, -} - -func (c *cryptoLogger) Error(message string, args ...interface{}) { - c.int.Errorfln(message, args...) -} - -func (c *cryptoLogger) Warn(message string, args ...interface{}) { - c.int.Warnfln(message, args...) -} - -func (c *cryptoLogger) Debug(message string, args ...interface{}) { - c.int.Debugfln(message, args...) -} - -func (c *cryptoLogger) Trace(message string, args ...interface{}) { - c.int.Logfln(levelTrace, message, args...) -} diff --git a/database/database.go b/database/database.go index 8fe1150..bf31907 100644 --- a/database/database.go +++ b/database/database.go @@ -17,21 +17,17 @@ package database import ( - "database/sql" "errors" - "fmt" "net" "time" "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" - log "maunium.net/go/maulogger/v2" - "go.mau.fi/whatsmeow/store" "go.mau.fi/whatsmeow/store/sqlstore" - "maunium.net/go/mautrix-whatsapp/config" "maunium.net/go/mautrix-whatsapp/database/upgrades" + "maunium.net/go/mautrix/util/dbutil" ) func init() { @@ -39,9 +35,7 @@ func init() { } type Database struct { - *sql.DB - log log.Logger - dialect string + *dbutil.Database User *UserQuery Portal *PortalQuery @@ -55,79 +49,46 @@ type Database struct { MediaBackfillRequest *MediaBackfillRequestQuery } -func New(cfg config.DatabaseConfig, baseLog log.Logger) (*Database, error) { - conn, err := sql.Open(cfg.Type, cfg.URI) - if err != nil { - return nil, err - } - - db := &Database{ - DB: conn, - log: baseLog.Sub("Database"), - dialect: cfg.Type, - } +func New(baseDB *dbutil.Database) *Database { + db := &Database{Database: baseDB} + db.UpgradeTable = upgrades.Table db.User = &UserQuery{ db: db, - log: db.log.Sub("User"), + log: db.Log.Sub("User"), } db.Portal = &PortalQuery{ db: db, - log: db.log.Sub("Portal"), + log: db.Log.Sub("Portal"), } db.Puppet = &PuppetQuery{ db: db, - log: db.log.Sub("Puppet"), + log: db.Log.Sub("Puppet"), } db.Message = &MessageQuery{ db: db, - log: db.log.Sub("Message"), + log: db.Log.Sub("Message"), } db.Reaction = &ReactionQuery{ db: db, - log: db.log.Sub("Reaction"), + log: db.Log.Sub("Reaction"), } db.DisappearingMessage = &DisappearingMessageQuery{ db: db, - log: db.log.Sub("DisappearingMessage"), + log: db.Log.Sub("DisappearingMessage"), } db.Backfill = &BackfillQuery{ db: db, - log: db.log.Sub("Backfill"), + log: db.Log.Sub("Backfill"), } db.HistorySync = &HistorySyncQuery{ db: db, - log: db.log.Sub("HistorySync"), + log: db.Log.Sub("HistorySync"), } db.MediaBackfillRequest = &MediaBackfillRequestQuery{ db: db, - log: db.log.Sub("MediaBackfillRequest"), + log: db.Log.Sub("MediaBackfillRequest"), } - - db.SetMaxOpenConns(cfg.MaxOpenConns) - db.SetMaxIdleConns(cfg.MaxIdleConns) - if len(cfg.ConnMaxIdleTime) > 0 { - maxIdleTimeDuration, err := time.ParseDuration(cfg.ConnMaxIdleTime) - if err != nil { - return nil, fmt.Errorf("failed to parse max_conn_idle_time: %w", err) - } - db.SetConnMaxIdleTime(maxIdleTimeDuration) - } - if len(cfg.ConnMaxLifetime) > 0 { - maxLifetimeDuration, err := time.ParseDuration(cfg.ConnMaxLifetime) - if err != nil { - return nil, fmt.Errorf("failed to parse max_conn_idle_time: %w", err) - } - db.SetConnMaxLifetime(maxLifetimeDuration) - } - return db, nil -} - -func (db *Database) Init() error { - return upgrades.Run(db.log.Sub("Upgrade"), db.dialect, db.DB) -} - -type Scannable interface { - Scan(...interface{}) error + return db } func isRetryableError(err error) bool { @@ -145,7 +106,7 @@ func isRetryableError(err error) bool { } func (db *Database) HandleSignalStoreError(device *store.Device, action string, attemptIndex int, err error) (retry bool) { - if db.dialect != "sqlite" && isRetryableError(err) { + if db.Dialect != dbutil.SQLite && isRetryableError(err) { sleepTime := time.Duration(attemptIndex*2) * time.Second device.Log.Warnf("Failed to %s (attempt #%d): %v - retrying in %v", action, attemptIndex+1, err, sleepTime) time.Sleep(sleepTime) diff --git a/database/disappearingmessage.go b/database/disappearingmessage.go index 4dd4d10..90769c1 100644 --- a/database/disappearingmessage.go +++ b/database/disappearingmessage.go @@ -24,6 +24,7 @@ import ( log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/util/dbutil" ) type DisappearingMessageQuery struct { @@ -94,7 +95,7 @@ type DisappearingMessage struct { ExpireAt time.Time } -func (msg *DisappearingMessage) Scan(row Scannable) *DisappearingMessage { +func (msg *DisappearingMessage) Scan(row dbutil.Scannable) *DisappearingMessage { var expireIn int64 var expireAt sql.NullInt64 err := row.Scan(&msg.RoomID, &msg.EventID, &expireIn, &expireAt) diff --git a/database/historysync.go b/database/historysync.go index 4415623..45bdc3d 100644 --- a/database/historysync.go +++ b/database/historysync.go @@ -27,7 +27,9 @@ import ( _ "github.com/mattn/go-sqlite3" log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/util/dbutil" ) type HistorySyncQuery struct { @@ -139,7 +141,7 @@ func (hsc *HistorySyncConversation) Upsert() { } } -func (hsc *HistorySyncConversation) Scan(row Scannable) *HistorySyncConversation { +func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) *HistorySyncConversation { err := row.Scan( &hsc.UserID, &hsc.ConversationID, @@ -166,7 +168,7 @@ func (hsc *HistorySyncConversation) Scan(row Scannable) *HistorySyncConversation func (hsq *HistorySyncQuery) GetNMostRecentConversations(userID id.UserID, n int) (conversations []*HistorySyncConversation) { nPtr := &n // Negative limit on SQLite means unlimited, but Postgres prefers a NULL limit. - if n < 0 && hsq.db.dialect == "postgres" { + if n < 0 && hsq.db.Dialect == dbutil.Postgres { nPtr = nil } rows, err := hsq.db.Query(getNMostRecentConversations, userID, nPtr) diff --git a/database/mediabackfillrequest.go b/database/mediabackfillrequest.go index 6a629c6..c206796 100644 --- a/database/mediabackfillrequest.go +++ b/database/mediabackfillrequest.go @@ -22,7 +22,9 @@ import ( _ "github.com/mattn/go-sqlite3" log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/util/dbutil" ) type MediaBackfillRequestStatus int @@ -100,7 +102,7 @@ func (mbr *MediaBackfillRequest) Upsert() { } } -func (mbr *MediaBackfillRequest) Scan(row Scannable) *MediaBackfillRequest { +func (mbr *MediaBackfillRequest) Scan(row dbutil.Scannable) *MediaBackfillRequest { err := row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error) if err != nil { if !errors.Is(err, sql.ErrNoRows) { diff --git a/database/message.go b/database/message.go index 993e9e5..e07c5aa 100644 --- a/database/message.go +++ b/database/message.go @@ -25,6 +25,7 @@ import ( log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/util/dbutil" "go.mau.fi/whatsmeow/types" ) @@ -163,7 +164,7 @@ func (msg *Message) IsFakeJID() bool { return strings.HasPrefix(msg.JID, "FAKE::") || msg.JID == string(msg.MXID) } -func (msg *Message) Scan(row Scannable) *Message { +func (msg *Message) Scan(row dbutil.Scannable) *Message { var ts int64 err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID) if err != nil { diff --git a/database/portal.go b/database/portal.go index febeb26..cb6c050 100644 --- a/database/portal.go +++ b/database/portal.go @@ -22,6 +22,7 @@ import ( log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/util/dbutil" "go.mau.fi/whatsmeow/types" ) @@ -152,7 +153,7 @@ type Portal struct { ExpirationTime uint32 } -func (portal *Portal) Scan(row Scannable) *Portal { +func (portal *Portal) Scan(row dbutil.Scannable) *Portal { var mxid, avatarURL, firstEventID, nextBatchID, relayUserID sql.NullString err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar, &avatarURL, &portal.Encrypted, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime) if err != nil { diff --git a/database/puppet.go b/database/puppet.go index 57ddf8b..f7261ed 100644 --- a/database/puppet.go +++ b/database/puppet.go @@ -20,7 +20,9 @@ import ( "database/sql" log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/util/dbutil" "go.mau.fi/whatsmeow/types" ) @@ -97,7 +99,7 @@ type Puppet struct { EnableReceipts bool } -func (puppet *Puppet) Scan(row Scannable) *Puppet { +func (puppet *Puppet) Scan(row dbutil.Scannable) *Puppet { var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString var quality sql.NullInt64 var enablePresence, enableReceipts sql.NullBool diff --git a/database/reaction.go b/database/reaction.go index 1126b25..d15be66 100644 --- a/database/reaction.go +++ b/database/reaction.go @@ -23,6 +23,7 @@ import ( log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/util/dbutil" "go.mau.fi/whatsmeow/types" ) @@ -85,7 +86,7 @@ type Reaction struct { JID types.MessageID } -func (reaction *Reaction) Scan(row Scannable) *Reaction { +func (reaction *Reaction) Scan(row dbutil.Scannable) *Reaction { err := row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { diff --git a/database/statestore.go b/database/statestore.go deleted file mode 100644 index 3a507da..0000000 --- a/database/statestore.go +++ /dev/null @@ -1,282 +0,0 @@ -// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. -// Copyright (C) 2022 Tulir Asokan -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package database - -import ( - "database/sql" - "encoding/json" - "errors" - "sync" - - log "maunium.net/go/maulogger/v2" - - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -type SQLStateStore struct { - *appservice.TypingStateStore - - db *Database - log log.Logger - - Typing map[id.RoomID]map[id.UserID]int64 - typingLock sync.RWMutex -} - -var _ appservice.StateStore = (*SQLStateStore)(nil) - -func NewSQLStateStore(db *Database) *SQLStateStore { - return &SQLStateStore{ - TypingStateStore: appservice.NewTypingStateStore(), - db: db, - log: db.log.Sub("StateStore"), - } -} - -func (store *SQLStateStore) IsRegistered(userID id.UserID) bool { - var isRegistered bool - err := store.db. - QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID). - Scan(&isRegistered) - if err != nil { - store.log.Warnfln("Failed to scan registration existence for %s: %v", userID, err) - } - return isRegistered -} - -func (store *SQLStateStore) MarkRegistered(userID id.UserID) { - _, err := store.db.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) - if err != nil { - store.log.Warnfln("Failed to mark %s as registered: %v", userID, err) - } -} - -func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent { - members := make(map[id.UserID]*event.MemberEventContent) - rows, err := store.db.Query("SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID) - if err != nil { - return members - } - var userID id.UserID - var member event.MemberEventContent - for rows.Next() { - err = rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL) - if err != nil { - store.log.Warnfln("Failed to scan member in %s: %v", roomID, err) - } else { - members[userID] = &member - } - } - return members -} - -func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership { - membership := event.MembershipLeave - err := store.db. - QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID). - Scan(&membership) - if err != nil && err != sql.ErrNoRows { - store.log.Warnfln("Failed to scan membership of %s in %s: %v", userID, roomID, err) - } - return membership -} - -func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent { - member, ok := store.TryGetMember(roomID, userID) - if !ok { - member.Membership = event.MembershipLeave - } - return member -} - -func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) { - var member event.MemberEventContent - err := store.db. - QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID). - Scan(&member.Membership, &member.Displayname, &member.AvatarURL) - if err != nil && err != sql.ErrNoRows { - store.log.Warnfln("Failed to scan member info of %s in %s: %v", userID, roomID, err) - } - return &member, err == nil -} - -func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) { - rows, err := store.db.Query(` - SELECT room_id FROM mx_user_profile - LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id - WHERE user_id=$1 AND portal.encrypted=true - `, userID) - if err != nil { - store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err) - return - } - for rows.Next() { - var roomID id.RoomID - err = rows.Scan(&roomID) - if err != nil { - store.log.Warnfln("Failed to scan room ID: %v", err) - } else { - rooms = append(rooms, roomID) - } - } - return -} - -func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool { - return store.IsMembership(roomID, userID, "join") -} - -func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool { - return store.IsMembership(roomID, userID, "join", "invite") -} - -func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool { - membership := store.GetMembership(roomID, userID) - for _, allowedMembership := range allowedMemberships { - if allowedMembership == membership { - return true - } - } - return false -} - -func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) { - _, err := store.db.Exec(` - INSERT INTO mx_user_profile (room_id, user_id, membership) VALUES ($1, $2, $3) - ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership - `, roomID, userID, membership) - if err != nil { - store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, membership, err) - } -} - -func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) { - _, err := store.db.Exec(` - INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url - `, roomID, userID, member.Membership, member.Displayname, member.AvatarURL) - if err != nil { - store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err) - } -} - -func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) { - levelsBytes, err := json.Marshal(levels) - if err != nil { - store.log.Errorfln("Failed to marshal power levels of %s: %v", roomID, err) - return - } - _, err = store.db.Exec(` - INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2) - ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels - `, roomID, levelsBytes) - if err != nil { - store.log.Warnfln("Failed to store power levels of %s: %v", roomID, err) - } -} - -func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) { - var data []byte - err := store.db. - QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID). - Scan(&data) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - store.log.Errorfln("Failed to scan power levels of %s: %v", roomID, err) - } - return - } - levels = &event.PowerLevelsEventContent{} - err = json.Unmarshal(data, levels) - if err != nil { - store.log.Errorfln("Failed to parse power levels of %s: %v", roomID, err) - return nil - } - return -} - -func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int { - if store.db.dialect == "postgres" { - var powerLevel int - err := store.db. - QueryRow(` - SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0) - FROM mx_room_state WHERE room_id=$1 - `, roomID, userID). - Scan(&powerLevel) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - store.log.Errorfln("Failed to scan power level of %s in %s: %v", userID, roomID, err) - } - return powerLevel - } - return store.GetPowerLevels(roomID).GetUserLevel(userID) -} - -func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int { - if store.db.dialect == "postgres" { - defaultType := "events_default" - defaultValue := 0 - if eventType.IsState() { - defaultType = "state_default" - defaultValue = 50 - } - var powerLevel int - err := store.db. - QueryRow(` - SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4) - FROM mx_room_state WHERE room_id=$1 - `, roomID, eventType.Type, defaultType, defaultValue). - Scan(&powerLevel) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - store.log.Errorfln("Failed to scan power level for %s in %s: %v", eventType, roomID, err) - } - return defaultValue - } - return powerLevel - } - return store.GetPowerLevels(roomID).GetEventLevel(eventType) -} - -func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool { - if store.db.dialect == "postgres" { - defaultType := "events_default" - defaultValue := 0 - if eventType.IsState() { - defaultType = "state_default" - defaultValue = 50 - } - var hasPower bool - err := store.db. - QueryRow(`SELECT - COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0) - >= - COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5) - FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue). - Scan(&hasPower) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - store.log.Errorfln("Failed to scan power level for %s in %s: %v", eventType, roomID, err) - } - return defaultValue == 0 - } - return hasPower - } - return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType) -} diff --git a/database/upgrades/00-latest-revision.sql b/database/upgrades/00-latest-revision.sql new file mode 100644 index 0000000..701108d --- /dev/null +++ b/database/upgrades/00-latest-revision.sql @@ -0,0 +1,181 @@ +-- v0 -> v48: Latest revision + +CREATE TABLE "user" ( + mxid TEXT PRIMARY KEY, + username TEXT UNIQUE, + agent SMALLINT, + device SMALLINT, + + management_room TEXT, + space_room TEXT, + + phone_last_seen BIGINT, + phone_last_pinged BIGINT, + + timezone TEXT +); + +CREATE TABLE portal ( + jid TEXT, + receiver TEXT, + mxid TEXT UNIQUE, + name TEXT NOT NULL, + topic TEXT NOT NULL, + avatar TEXT NOT NULL, + avatar_url TEXT, + encrypted BOOLEAN NOT NULL DEFAULT false, + + first_event_id TEXT, + next_batch_id TEXT, + relay_user_id TEXT, + expiration_time BIGINT NOT NULL DEFAULT 0, + + PRIMARY KEY (jid, receiver) +); + +CREATE TABLE puppet ( + username TEXT PRIMARY KEY, + displayname TEXT, + name_quality SMALLINT, + avatar TEXT, + avatar_url TEXT, + + custom_mxid TEXT, + access_token TEXT, + next_batch TEXT, + + enable_presence BOOLEAN NOT NULL DEFAULT true, + enable_receipts BOOLEAN NOT NULL DEFAULT true +); + +-- only: postgres +CREATE TYPE error_type AS ENUM ('', 'decryption_failed', 'media_not_found'); + +CREATE TABLE message ( + chat_jid TEXT, + chat_receiver TEXT, + jid TEXT, + mxid TEXT UNIQUE, + sender TEXT, + timestamp BIGINT, + sent BOOLEAN, + error error_type, + type TEXT, + + broadcast_list_jid TEXT, + + PRIMARY KEY (chat_jid, chat_receiver, jid), + FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE +); + +CREATE TABLE reaction ( + chat_jid TEXT, + chat_receiver TEXT, + target_jid TEXT, + sender TEXT, + + mxid TEXT NOT NULL, + jid TEXT NOT NULL, + + PRIMARY KEY (chat_jid, chat_receiver, target_jid, sender), + FOREIGN KEY (chat_jid, chat_receiver, target_jid) REFERENCES message(chat_jid, chat_receiver, jid) + ON DELETE CASCADE ON UPDATE CASCADE +); + +CREATE TABLE disappearing_message ( + room_id TEXT, + event_id TEXT, + expire_in BIGINT NOT NULL, + expire_at BIGINT, + PRIMARY KEY (room_id, event_id) +); + +CREATE TABLE user_portal ( + user_mxid TEXT, + portal_jid TEXT, + portal_receiver TEXT, + last_read_ts BIGINT NOT NULL DEFAULT 0, + in_space BOOLEAN NOT NULL DEFAULT false, + PRIMARY KEY (user_mxid, portal_jid, portal_receiver), + FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE, + FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON UPDATE CASCADE ON DELETE CASCADE +); + +CREATE TABLE backfill_queue ( + queue_id INTEGER PRIMARY KEY + -- only: postgres + GENERATED ALWAYS AS IDENTITY + , + user_mxid TEXT, + type INTEGER NOT NULL, + priority INTEGER NOT NULL, + portal_jid TEXT, + portal_receiver TEXT, + time_start TIMESTAMP, + dispatch_time TIMESTAMP, + completed_at TIMESTAMP, + batch_delay INTEGER, + max_batch_events INTEGER NOT NULL, + max_total_events INTEGER, + + FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE +); + +CREATE TABLE backfill_state ( + user_mxid TEXT, + portal_jid TEXT, + portal_receiver TEXT, + processing_batch BOOLEAN, + backfill_complete BOOLEAN, + first_expected_ts TIMESTAMP, + PRIMARY KEY (user_mxid, portal_jid, portal_receiver), + FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal (jid, receiver) ON DELETE CASCADE +); + +CREATE TABLE media_backfill_requests ( + user_mxid TEXT, + portal_jid TEXT, + portal_receiver TEXT, + event_id TEXT, + media_key bytea, + status INTEGER, + error TEXT, + PRIMARY KEY (user_mxid, portal_jid, portal_receiver, event_id), + FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE, + FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON UPDATE CASCADE ON DELETE CASCADE +); + +CREATE TABLE history_sync_conversation ( + user_mxid TEXT, + conversation_id TEXT, + portal_jid TEXT, + portal_receiver TEXT, + + last_message_timestamp TIMESTAMP, + archived BOOLEAN, + pinned INTEGER, + mute_end_time TIMESTAMP, + disappearing_mode INTEGER, + end_of_history_transfer_type INTEGER, + ephemeral_Expiration INTEGER, + marked_as_unread BOOLEAN, + unread_count INTEGER, + + PRIMARY KEY (user_mxid, conversation_id), + FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE, + FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON UPDATE CASCADE ON DELETE CASCADE +); + +CREATE TABLE history_sync_message ( + user_mxid TEXT, + conversation_id TEXT, + message_id TEXT, + timestamp TIMESTAMP, + data bytea, + inserted_time TIMESTAMP, + + PRIMARY KEY (user_mxid, conversation_id, message_id), + FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE, + FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE +); diff --git a/database/upgrades/2018-09-01-initial-schema.go b/database/upgrades/2018-09-01-initial-schema.go deleted file mode 100644 index f142726..0000000 --- a/database/upgrades/2018-09-01-initial-schema.go +++ /dev/null @@ -1,67 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[0] = upgrade{"Initial schema", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`CREATE TABLE IF NOT EXISTS portal ( - jid VARCHAR(255), - receiver VARCHAR(255), - mxid VARCHAR(255) UNIQUE, - - name VARCHAR(255) NOT NULL, - topic VARCHAR(255) NOT NULL, - avatar VARCHAR(255) NOT NULL, - - PRIMARY KEY (jid, receiver) - )`) - if err != nil { - return err - } - - _, err = tx.Exec(`CREATE TABLE IF NOT EXISTS puppet ( - jid VARCHAR(255) PRIMARY KEY, - avatar VARCHAR(255), - displayname VARCHAR(255), - name_quality SMALLINT - )`) - if err != nil { - return err - } - - _, err = tx.Exec(`CREATE TABLE IF NOT EXISTS "user" ( - mxid VARCHAR(255) PRIMARY KEY, - jid VARCHAR(255) UNIQUE, - - management_room VARCHAR(255), - - client_id VARCHAR(255), - client_token VARCHAR(255), - server_token VARCHAR(255), - enc_key bytea, - mac_key bytea - )`) - if err != nil { - return err - } - - _, err = tx.Exec(`CREATE TABLE IF NOT EXISTS message ( - chat_jid VARCHAR(255), - chat_receiver VARCHAR(255), - jid VARCHAR(255), - mxid VARCHAR(255) NOT NULL UNIQUE, - sender VARCHAR(255) NOT NULL, - content bytea NOT NULL, - - PRIMARY KEY (chat_jid, chat_receiver, jid), - FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE - )`) - if err != nil { - return err - } - - return nil - }} -} diff --git a/database/upgrades/2019-05-21-message-timestamp-column.go b/database/upgrades/2019-05-21-message-timestamp-column.go deleted file mode 100644 index cb93614..0000000 --- a/database/upgrades/2019-05-21-message-timestamp-column.go +++ /dev/null @@ -1,15 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[2] = upgrade{"Add timestamp column to messages", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec("ALTER TABLE message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0") - if err != nil { - return err - } - return nil - }} -} diff --git a/database/upgrades/2019-05-22-user-last-connection-column.go b/database/upgrades/2019-05-22-user-last-connection-column.go deleted file mode 100644 index 3e1a236..0000000 --- a/database/upgrades/2019-05-22-user-last-connection-column.go +++ /dev/null @@ -1,15 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[3] = upgrade{"Add last_connection column to users", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN last_connection BIGINT NOT NULL DEFAULT 0`) - if err != nil { - return err - } - return nil - }} -} diff --git a/database/upgrades/2019-05-23-puppet-custom-mxid-columns.go b/database/upgrades/2019-05-23-puppet-custom-mxid-columns.go deleted file mode 100644 index 2f17154..0000000 --- a/database/upgrades/2019-05-23-puppet-custom-mxid-columns.go +++ /dev/null @@ -1,23 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[5] = upgrade{"Add columns to store custom puppet info", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN custom_mxid VARCHAR(255)`) - if err != nil { - return err - } - _, err = tx.Exec(`ALTER TABLE puppet ADD COLUMN access_token VARCHAR(1023)`) - if err != nil { - return err - } - _, err = tx.Exec(`ALTER TABLE puppet ADD COLUMN next_batch VARCHAR(255)`) - if err != nil { - return err - } - return nil - }} -} diff --git a/database/upgrades/2019-05-28-user-portal-table.go b/database/upgrades/2019-05-28-user-portal-table.go deleted file mode 100644 index 18d8550..0000000 --- a/database/upgrades/2019-05-28-user-portal-table.go +++ /dev/null @@ -1,19 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[6] = upgrade{"Add user-portal mapping table", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`CREATE TABLE user_portal ( - user_jid VARCHAR(255), - portal_jid VARCHAR(255), - portal_receiver VARCHAR(255), - PRIMARY KEY (user_jid, portal_jid, portal_receiver), - FOREIGN KEY (user_jid) REFERENCES "user"(jid) ON DELETE CASCADE, - FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE - )`) - return err - }} -} diff --git a/database/upgrades/2019-06-01-avatar-url-fields.go b/database/upgrades/2019-06-01-avatar-url-fields.go deleted file mode 100644 index 938b291..0000000 --- a/database/upgrades/2019-06-01-avatar-url-fields.go +++ /dev/null @@ -1,19 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[7] = upgrade{"Add columns to store avatar MXC URIs", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN avatar_url VARCHAR(255)`) - if err != nil { - return err - } - _, err = tx.Exec(`ALTER TABLE portal ADD COLUMN avatar_url VARCHAR(255)`) - if err != nil { - return err - } - return nil - }} -} diff --git a/database/upgrades/2019-08-10-portal-in-community-field.go b/database/upgrades/2019-08-10-portal-in-community-field.go deleted file mode 100644 index 44893fd..0000000 --- a/database/upgrades/2019-08-10-portal-in-community-field.go +++ /dev/null @@ -1,12 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[8] = upgrade{"Add columns to store portal in filtering community meta", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE user_portal ADD COLUMN in_community BOOLEAN NOT NULL DEFAULT FALSE`) - return err - }} -} diff --git a/database/upgrades/2019-08-25-move-state-store-to-db.go b/database/upgrades/2019-08-25-move-state-store-to-db.go deleted file mode 100644 index b5201e1..0000000 --- a/database/upgrades/2019-08-25-move-state-store-to-db.go +++ /dev/null @@ -1,39 +0,0 @@ -package upgrades - -import ( - "database/sql" - "strings" -) - -func init() { - userProfileTable := `CREATE TABLE mx_user_profile ( - room_id VARCHAR(255), - user_id VARCHAR(255), - membership VARCHAR(15) NOT NULL, - PRIMARY KEY (room_id, user_id) - )` - - roomStateTable := `CREATE TABLE mx_room_state ( - room_id VARCHAR(255) PRIMARY KEY, - power_levels TEXT - )` - - registrationsTable := `CREATE TABLE mx_registrations ( - user_id VARCHAR(255) PRIMARY KEY - )` - - upgrades[9] = upgrade{"Move state store to main DB", func(tx *sql.Tx, ctx context) error { - if ctx.dialect == Postgres { - roomStateTable = strings.Replace(roomStateTable, "TEXT", "JSONB", 1) - } - - if _, err := tx.Exec(userProfileTable); err != nil { - return err - } else if _, err = tx.Exec(roomStateTable); err != nil { - return err - } else if _, err = tx.Exec(registrationsTable); err != nil { - return err - } - return nil - }} -} diff --git a/database/upgrades/2019-11-10-full-member-state-store.go b/database/upgrades/2019-11-10-full-member-state-store.go deleted file mode 100644 index 4040e7f..0000000 --- a/database/upgrades/2019-11-10-full-member-state-store.go +++ /dev/null @@ -1,16 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[10] = upgrade{"Add columns to store full member info in state store", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE mx_user_profile ADD COLUMN displayname TEXT`) - if err != nil { - return err - } - _, err = tx.Exec(`ALTER TABLE mx_user_profile ADD COLUMN avatar_url VARCHAR(255)`) - return err - }} -} diff --git a/database/upgrades/2019-11-12-fix-room-topic-length.go b/database/upgrades/2019-11-12-fix-room-topic-length.go deleted file mode 100644 index 3532d35..0000000 --- a/database/upgrades/2019-11-12-fix-room-topic-length.go +++ /dev/null @@ -1,16 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[11] = upgrade{"Adjust the length of column topic in portal", func(tx *sql.Tx, ctx context) error { - if ctx.dialect == SQLite { - // SQLite doesn't support constraint updates, but it isn't that careful with constraints anyway. - return nil - } - _, err := tx.Exec(`ALTER TABLE portal ALTER COLUMN topic TYPE VARCHAR(512)`) - return err - }} -} diff --git a/database/upgrades/2020-05-09-add-portal-encrypted-field.go b/database/upgrades/2020-05-09-add-portal-encrypted-field.go deleted file mode 100644 index ef0f963..0000000 --- a/database/upgrades/2020-05-09-add-portal-encrypted-field.go +++ /dev/null @@ -1,12 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[12] = upgrade{"Add encryption status to portal table", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE portal ADD COLUMN encrypted BOOLEAN NOT NULL DEFAULT false`) - return err - }} -} diff --git a/database/upgrades/2020-05-09-crypto-store.go b/database/upgrades/2020-05-09-crypto-store.go deleted file mode 100644 index 8be6cd8..0000000 --- a/database/upgrades/2020-05-09-crypto-store.go +++ /dev/null @@ -1,73 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[13] = upgrade{"Add crypto store to database", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`CREATE TABLE crypto_account ( - device_id VARCHAR(255) PRIMARY KEY, - shared BOOLEAN NOT NULL, - sync_token TEXT NOT NULL, - account bytea NOT NULL - )`) - if err != nil { - return err - } - _, err = tx.Exec(`CREATE TABLE crypto_message_index ( - sender_key CHAR(43), - session_id CHAR(43), - "index" INTEGER, - event_id VARCHAR(255) NOT NULL, - timestamp BIGINT NOT NULL, - - PRIMARY KEY (sender_key, session_id, "index") - )`) - if err != nil { - return err - } - _, err = tx.Exec(`CREATE TABLE crypto_tracked_user ( - user_id VARCHAR(255) PRIMARY KEY - )`) - if err != nil { - return err - } - _, err = tx.Exec(`CREATE TABLE crypto_device ( - user_id VARCHAR(255), - device_id VARCHAR(255), - identity_key CHAR(43) NOT NULL, - signing_key CHAR(43) NOT NULL, - trust SMALLINT NOT NULL, - deleted BOOLEAN NOT NULL, - name VARCHAR(255) NOT NULL, - - PRIMARY KEY (user_id, device_id) - )`) - if err != nil { - return err - } - _, err = tx.Exec(`CREATE TABLE crypto_olm_session ( - session_id CHAR(43) PRIMARY KEY, - sender_key CHAR(43) NOT NULL, - session bytea NOT NULL, - created_at timestamp NOT NULL, - last_used timestamp NOT NULL - )`) - if err != nil { - return err - } - _, err = tx.Exec(`CREATE TABLE crypto_megolm_inbound_session ( - session_id CHAR(43) PRIMARY KEY, - sender_key CHAR(43) NOT NULL, - signing_key CHAR(43) NOT NULL, - room_id VARCHAR(255) NOT NULL, - session bytea NOT NULL, - forwarding_chains bytea NOT NULL - )`) - if err != nil { - return err - } - return nil - }} -} diff --git a/database/upgrades/2020-05-12-outbound-group-session-store.go b/database/upgrades/2020-05-12-outbound-group-session-store.go deleted file mode 100644 index 0f108a6..0000000 --- a/database/upgrades/2020-05-12-outbound-group-session-store.go +++ /dev/null @@ -1,25 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[14] = upgrade{"Add outbound group sessions to database", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`CREATE TABLE crypto_megolm_outbound_session ( - room_id VARCHAR(255) PRIMARY KEY, - session_id CHAR(43) NOT NULL UNIQUE, - session bytea NOT NULL, - shared BOOLEAN NOT NULL, - max_messages INTEGER NOT NULL, - message_count INTEGER NOT NULL, - max_age BIGINT NOT NULL, - created_at timestamp NOT NULL, - last_used timestamp NOT NULL - )`) - if err != nil { - return err - } - return nil - }} -} diff --git a/database/upgrades/2020-07-10-custom-puppet-presence-toggle.go b/database/upgrades/2020-07-10-custom-puppet-presence-toggle.go deleted file mode 100644 index 9eddbce..0000000 --- a/database/upgrades/2020-07-10-custom-puppet-presence-toggle.go +++ /dev/null @@ -1,12 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[15] = upgrade{"Add enable_presence column for puppets", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN enable_presence BOOLEAN NOT NULL DEFAULT true`) - return err - }} -} diff --git a/database/upgrades/2020-07-10-update-crypto-store.go b/database/upgrades/2020-07-10-update-crypto-store.go deleted file mode 100644 index e33b20f..0000000 --- a/database/upgrades/2020-07-10-update-crypto-store.go +++ /dev/null @@ -1,13 +0,0 @@ -package upgrades - -import ( - "database/sql" - - "maunium.net/go/mautrix/crypto/sql_store_upgrade" -) - -func init() { - upgrades[16] = upgrade{"Add account_id to crypto store", func(tx *sql.Tx, c context) error { - return sql_store_upgrade.Upgrades[1](tx, c.dialect.String()) - }} -} diff --git a/database/upgrades/2020-07-10-x-custom-puppet-receipts-toggle.go b/database/upgrades/2020-07-10-x-custom-puppet-receipts-toggle.go deleted file mode 100644 index 28a54f7..0000000 --- a/database/upgrades/2020-07-10-x-custom-puppet-receipts-toggle.go +++ /dev/null @@ -1,12 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[17] = upgrade{"Add enable_receipts column for puppets", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN enable_receipts BOOLEAN NOT NULL DEFAULT true`) - return err - }} -} diff --git a/database/upgrades/2020-08-03-update-crypto-store.go b/database/upgrades/2020-08-03-update-crypto-store.go deleted file mode 100644 index 859c250..0000000 --- a/database/upgrades/2020-08-03-update-crypto-store.go +++ /dev/null @@ -1,13 +0,0 @@ -package upgrades - -import ( - "database/sql" - - "maunium.net/go/mautrix/crypto/sql_store_upgrade" -) - -func init() { - upgrades[18] = upgrade{"Add megolm withheld data to crypto store", func(tx *sql.Tx, c context) error { - return sql_store_upgrade.Upgrades[2](tx, c.dialect.String()) - }} -} diff --git a/database/upgrades/2020-10-28-crypto-store-cross-signing.go b/database/upgrades/2020-10-28-crypto-store-cross-signing.go deleted file mode 100644 index 5b7aab1..0000000 --- a/database/upgrades/2020-10-28-crypto-store-cross-signing.go +++ /dev/null @@ -1,13 +0,0 @@ -package upgrades - -import ( - "database/sql" - - "maunium.net/go/mautrix/crypto/sql_store_upgrade" -) - -func init() { - upgrades[19] = upgrade{"Add cross-signing keys to crypto store", func(tx *sql.Tx, c context) error { - return sql_store_upgrade.Upgrades[3](tx, c.dialect.String()) - }} -} diff --git a/database/upgrades/2021-02-17-message-sent-status.go b/database/upgrades/2021-02-17-message-sent-status.go deleted file mode 100644 index a5852b0..0000000 --- a/database/upgrades/2021-02-17-message-sent-status.go +++ /dev/null @@ -1,12 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[20] = upgrade{"Add sent column for messages", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE message ADD COLUMN sent BOOLEAN NOT NULL DEFAULT true`) - return err - }} -} diff --git a/database/upgrades/2021-08-19-remove-message-content.go b/database/upgrades/2021-08-19-remove-message-content.go deleted file mode 100644 index c3c7611..0000000 --- a/database/upgrades/2021-08-19-remove-message-content.go +++ /dev/null @@ -1,44 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[21] = upgrade{"Remove message content from local database", func(tx *sql.Tx, ctx context) error { - if ctx.dialect == SQLite { - _, err := tx.Exec("ALTER TABLE message RENAME TO old_message") - if err != nil { - return err - } - _, err = tx.Exec(`CREATE TABLE IF NOT EXISTS message ( - chat_jid TEXT, - chat_receiver TEXT, - jid TEXT, - mxid TEXT NOT NULL UNIQUE, - sender TEXT NOT NULL, - timestamp BIGINT NOT NULL, - sent BOOLEAN NOT NULL, - - PRIMARY KEY (chat_jid, chat_receiver, jid), - FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE - )`) - if err != nil { - return err - } - _, err = tx.Exec("INSERT INTO message SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent FROM old_message") - return err - } else { - _, err := tx.Exec(`ALTER TABLE message DROP COLUMN content`) - if err != nil { - return err - } - _, err = tx.Exec(`ALTER TABLE message ALTER COLUMN timestamp DROP DEFAULT`) - if err != nil { - return err - } - _, err = tx.Exec(`ALTER TABLE message ALTER COLUMN sent DROP DEFAULT`) - return err - } - }} -} diff --git a/database/upgrades/2021-08-19-varchar-to-text-crypto.go b/database/upgrades/2021-08-19-varchar-to-text-crypto.go deleted file mode 100644 index 39f5a09..0000000 --- a/database/upgrades/2021-08-19-varchar-to-text-crypto.go +++ /dev/null @@ -1,13 +0,0 @@ -package upgrades - -import ( - "database/sql" - - "maunium.net/go/mautrix/crypto/sql_store_upgrade" -) - -func init() { - upgrades[23] = upgrade{"Replace VARCHAR(255) with TEXT in the crypto database", func(tx *sql.Tx, ctx context) error { - return sql_store_upgrade.Upgrades[4](tx, ctx.dialect.String()) - }} -} diff --git a/database/upgrades/2021-08-19-varchar-to-text.go b/database/upgrades/2021-08-19-varchar-to-text.go deleted file mode 100644 index 9fcd3ae..0000000 --- a/database/upgrades/2021-08-19-varchar-to-text.go +++ /dev/null @@ -1,48 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[22] = upgrade{"Replace VARCHAR(255) with TEXT in the database", func(tx *sql.Tx, ctx context) error { - if ctx.dialect == SQLite { - // SQLite doesn't enforce varchar sizes anyway - return nil - } - return execMany(tx, - `ALTER TABLE message ALTER COLUMN chat_jid TYPE TEXT`, - `ALTER TABLE message ALTER COLUMN chat_receiver TYPE TEXT`, - `ALTER TABLE message ALTER COLUMN jid TYPE TEXT`, - `ALTER TABLE message ALTER COLUMN mxid TYPE TEXT`, - `ALTER TABLE message ALTER COLUMN sender TYPE TEXT`, - - `ALTER TABLE portal ALTER COLUMN jid TYPE TEXT`, - `ALTER TABLE portal ALTER COLUMN receiver TYPE TEXT`, - `ALTER TABLE portal ALTER COLUMN mxid TYPE TEXT`, - `ALTER TABLE portal ALTER COLUMN name TYPE TEXT`, - `ALTER TABLE portal ALTER COLUMN topic TYPE TEXT`, - `ALTER TABLE portal ALTER COLUMN avatar TYPE TEXT`, - `ALTER TABLE portal ALTER COLUMN avatar_url TYPE TEXT`, - - `ALTER TABLE puppet ALTER COLUMN jid TYPE TEXT`, - `ALTER TABLE puppet ALTER COLUMN avatar TYPE TEXT`, - `ALTER TABLE puppet ALTER COLUMN displayname TYPE TEXT`, - `ALTER TABLE puppet ALTER COLUMN custom_mxid TYPE TEXT`, - `ALTER TABLE puppet ALTER COLUMN access_token TYPE TEXT`, - `ALTER TABLE puppet ALTER COLUMN next_batch TYPE TEXT`, - `ALTER TABLE puppet ALTER COLUMN avatar_url TYPE TEXT`, - - `ALTER TABLE "user" ALTER COLUMN mxid TYPE TEXT`, - `ALTER TABLE "user" ALTER COLUMN jid TYPE TEXT`, - `ALTER TABLE "user" ALTER COLUMN management_room TYPE TEXT`, - `ALTER TABLE "user" ALTER COLUMN client_id TYPE TEXT`, - `ALTER TABLE "user" ALTER COLUMN client_token TYPE TEXT`, - `ALTER TABLE "user" ALTER COLUMN server_token TYPE TEXT`, - - `ALTER TABLE user_portal ALTER COLUMN user_jid TYPE TEXT`, - `ALTER TABLE user_portal ALTER COLUMN portal_jid TYPE TEXT`, - `ALTER TABLE user_portal ALTER COLUMN portal_receiver TYPE TEXT`, - ) - }} -} diff --git a/database/upgrades/2021-10-21-add-whatsmeow-store.go b/database/upgrades/2021-10-21-add-whatsmeow-store.go deleted file mode 100644 index 41d7a8e..0000000 --- a/database/upgrades/2021-10-21-add-whatsmeow-store.go +++ /dev/null @@ -1,13 +0,0 @@ -package upgrades - -import ( - "database/sql" - - "go.mau.fi/whatsmeow/store/sqlstore" -) - -func init() { - upgrades[24] = upgrade{"Add whatsmeow state store", func(tx *sql.Tx, ctx context) error { - return sqlstore.Upgrades[0](tx, sqlstore.NewWithDB(ctx.db, ctx.dialect.String(), nil)) - }} -} diff --git a/database/upgrades/2021-10-21-multidevice-updates.go b/database/upgrades/2021-10-21-multidevice-updates.go deleted file mode 100644 index 20e9b62..0000000 --- a/database/upgrades/2021-10-21-multidevice-updates.go +++ /dev/null @@ -1,93 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[25] = upgrade{"Update things for multidevice", func(tx *sql.Tx, ctx context) error { - // This is probably not necessary - _, err := tx.Exec("DROP TABLE user_portal") - if err != nil { - return err - } - - // Remove invalid puppet rows - _, err = tx.Exec("DELETE FROM puppet WHERE jid LIKE '%@g.us' OR jid LIKE '%@broadcast'") - if err != nil { - return err - } - // Remove the suffix from puppets since they'll all have the same suffix - _, err = tx.Exec("UPDATE puppet SET jid=REPLACE(jid, '@s.whatsapp.net', '')") - if err != nil { - return err - } - // Rename column to correctly represent the new content - _, err = tx.Exec("ALTER TABLE puppet RENAME COLUMN jid TO username") - if err != nil { - return err - } - - if ctx.dialect == SQLite { - // Message content was removed from the main message table earlier, but the backup table still exists for SQLite - _, err = tx.Exec("DROP TABLE IF EXISTS old_message") - - _, err = tx.Exec(`ALTER TABLE "user" RENAME TO old_user`) - if err != nil { - return err - } - _, err = tx.Exec(`CREATE TABLE "user" ( - mxid TEXT PRIMARY KEY, - username TEXT UNIQUE, - agent SMALLINT, - device SMALLINT, - management_room TEXT - )`) - if err != nil { - return err - } - - // No need to copy auth data, users need to relogin anyway - _, err = tx.Exec(`INSERT INTO "user" (mxid, management_room) SELECT mxid, management_room FROM old_user`) - if err != nil { - return err - } - - _, err = tx.Exec("DROP TABLE old_user") - if err != nil { - return err - } - } else { - // The jid column never actually contained the full JID, so let's rename it. - _, err = tx.Exec(`ALTER TABLE "user" RENAME COLUMN jid TO username`) - if err != nil { - return err - } - - // The auth data is now in the whatsmeow_device table. - for _, column := range []string{"last_connection", "client_id", "client_token", "server_token", "enc_key", "mac_key"} { - _, err = tx.Exec(`ALTER TABLE "user" DROP COLUMN ` + column) - if err != nil { - return err - } - } - - // The whatsmeow_device table is keyed by the full JID, so we need to store the other parts of the JID here too. - _, err = tx.Exec(`ALTER TABLE "user" ADD COLUMN agent SMALLINT`) - if err != nil { - return err - } - _, err = tx.Exec(`ALTER TABLE "user" ADD COLUMN device SMALLINT`) - if err != nil { - return err - } - - // Clear all usernames, the users need to relogin anyway. - _, err = tx.Exec(`UPDATE "user" SET username=null`) - if err != nil { - return err - } - } - return nil - }} -} diff --git a/database/upgrades/2021-10-26-portal-origin-event-id.go b/database/upgrades/2021-10-26-portal-origin-event-id.go deleted file mode 100644 index 37b8908..0000000 --- a/database/upgrades/2021-10-26-portal-origin-event-id.go +++ /dev/null @@ -1,19 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[26] = upgrade{"Add columns to store infinite backfill pointers for portals", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE portal ADD COLUMN first_event_id TEXT NOT NULL DEFAULT ''`) - if err != nil { - return err - } - _, err = tx.Exec(`ALTER TABLE portal ADD COLUMN next_batch_id TEXT NOT NULL DEFAULT ''`) - if err != nil { - return err - } - return nil - }} -} diff --git a/database/upgrades/2021-10-27-message-decryption-errors.go b/database/upgrades/2021-10-27-message-decryption-errors.go deleted file mode 100644 index 288709e..0000000 --- a/database/upgrades/2021-10-27-message-decryption-errors.go +++ /dev/null @@ -1,12 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[27] = upgrade{"Add marker for WhatsApp decryption errors in message table", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE message ADD COLUMN decryption_error BOOLEAN NOT NULL DEFAULT false`) - return err - }} -} diff --git a/database/upgrades/2021-10-28-portal-relay-user.go b/database/upgrades/2021-10-28-portal-relay-user.go deleted file mode 100644 index 81beedc..0000000 --- a/database/upgrades/2021-10-28-portal-relay-user.go +++ /dev/null @@ -1,12 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[28] = upgrade{"Add relay user field to portal table", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE portal ADD COLUMN relay_user_id TEXT`) - return err - }} -} diff --git a/database/upgrades/2021-10-30-varchar-to-text-state-store.go b/database/upgrades/2021-10-30-varchar-to-text-state-store.go deleted file mode 100644 index 114bf12..0000000 --- a/database/upgrades/2021-10-30-varchar-to-text-state-store.go +++ /dev/null @@ -1,22 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[29] = upgrade{"Replace VARCHAR(255) with TEXT in the Matrix state store", func(tx *sql.Tx, ctx context) error { - if ctx.dialect == SQLite { - // SQLite doesn't enforce varchar sizes anyway - return nil - } - return execMany(tx, - `ALTER TABLE mx_registrations ALTER COLUMN user_id TYPE TEXT`, - `ALTER TABLE mx_room_state ALTER COLUMN room_id TYPE TEXT`, - `ALTER TABLE mx_user_profile ALTER COLUMN room_id TYPE TEXT`, - `ALTER TABLE mx_user_profile ALTER COLUMN user_id TYPE TEXT`, - `ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT`, - `ALTER TABLE mx_user_profile ALTER COLUMN avatar_url TYPE TEXT`, - ) - }} -} diff --git a/database/upgrades/2021-11-30-store-last-read-state.go b/database/upgrades/2021-11-30-store-last-read-state.go deleted file mode 100644 index b5c3ec6..0000000 --- a/database/upgrades/2021-11-30-store-last-read-state.go +++ /dev/null @@ -1,22 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[30] = upgrade{"Store last read message timestamp in database", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`CREATE TABLE user_portal ( - user_mxid TEXT, - portal_jid TEXT, - portal_receiver TEXT, - - last_read_ts BIGINT NOT NULL DEFAULT 0, - - PRIMARY KEY (user_mxid, portal_jid, portal_receiver), - FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, - FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE - )`) - return err - }} -} diff --git a/database/upgrades/2021-12-22-crypto-store-last-used.go b/database/upgrades/2021-12-22-crypto-store-last-used.go deleted file mode 100644 index 534fe0a..0000000 --- a/database/upgrades/2021-12-22-crypto-store-last-used.go +++ /dev/null @@ -1,13 +0,0 @@ -package upgrades - -import ( - "database/sql" - - "maunium.net/go/mautrix/crypto/sql_store_upgrade" -) - -func init() { - upgrades[31] = upgrade{"Split last_used into last_encrypted and last_decrypted in crypto store", func(tx *sql.Tx, c context) error { - return sql_store_upgrade.Upgrades[5](tx, c.dialect.String()) - }} -} diff --git a/database/upgrades/2021-12-25-broadcast-list-message-source.go b/database/upgrades/2021-12-25-broadcast-list-message-source.go deleted file mode 100644 index 11059de..0000000 --- a/database/upgrades/2021-12-25-broadcast-list-message-source.go +++ /dev/null @@ -1,12 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[32] = upgrade{"Store source broadcast list in message table", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE message ADD COLUMN broadcast_list_jid TEXT`) - return err - }} -} diff --git a/database/upgrades/2021-12-29-personal-filtering-spaces.go b/database/upgrades/2021-12-29-personal-filtering-spaces.go deleted file mode 100644 index 4dbf3df..0000000 --- a/database/upgrades/2021-12-29-personal-filtering-spaces.go +++ /dev/null @@ -1,16 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[33] = upgrade{"Add personal filtering space info to user tables", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN space_room TEXT NOT NULL DEFAULT ''`) - if err != nil { - return err - } - _, err = tx.Exec(`ALTER TABLE user_portal ADD COLUMN in_space BOOLEAN NOT NULL DEFAULT false`) - return err - }} -} diff --git a/database/upgrades/2022-01-07-disappearing-messages.go b/database/upgrades/2022-01-07-disappearing-messages.go deleted file mode 100644 index 888597f..0000000 --- a/database/upgrades/2022-01-07-disappearing-messages.go +++ /dev/null @@ -1,20 +0,0 @@ -package upgrades - -import "database/sql" - -func init() { - upgrades[34] = upgrade{"Add support for disappearing messages", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE portal ADD COLUMN expiration_time BIGINT NOT NULL DEFAULT 0 CHECK (expiration_time >= 0 AND expiration_time < 4294967296)`) - if err != nil { - return err - } - _, err = tx.Exec(`CREATE TABLE disappearing_message ( - room_id TEXT, - event_id TEXT, - expire_in BIGINT NOT NULL, - expire_at BIGINT, - PRIMARY KEY (room_id, event_id) - )`) - return err - }} -} diff --git a/database/upgrades/2022-01-24-phone-last-seen-ts.go b/database/upgrades/2022-01-24-phone-last-seen-ts.go deleted file mode 100644 index 53403a2..0000000 --- a/database/upgrades/2022-01-24-phone-last-seen-ts.go +++ /dev/null @@ -1,10 +0,0 @@ -package upgrades - -import "database/sql" - -func init() { - upgrades[35] = upgrade{"Store approximate last seen timestamp of the main device", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN phone_last_seen BIGINT`) - return err - }} -} diff --git a/database/upgrades/2022-02-10-message-error-string.go b/database/upgrades/2022-02-10-message-error-string.go deleted file mode 100644 index d3fcc6c..0000000 --- a/database/upgrades/2022-02-10-message-error-string.go +++ /dev/null @@ -1,30 +0,0 @@ -package upgrades - -import "database/sql" - -func init() { - upgrades[36] = upgrade{"Store message error type as string", func(tx *sql.Tx, ctx context) error { - if ctx.dialect == Postgres { - _, err := tx.Exec("CREATE TYPE error_type AS ENUM ('', 'decryption_failed', 'media_not_found')") - if err != nil { - return err - } - } - _, err := tx.Exec("ALTER TABLE message ADD COLUMN error error_type NOT NULL DEFAULT ''") - if err != nil { - return err - } - _, err = tx.Exec("UPDATE message SET error='decryption_failed' WHERE decryption_error=true") - if err != nil { - return err - } - if ctx.dialect == Postgres { - // TODO do this on sqlite at some point - _, err = tx.Exec("ALTER TABLE message DROP COLUMN decryption_error") - if err != nil { - return err - } - } - return nil - }} -} diff --git a/database/upgrades/2022-02-18-phone-ping-ts.go b/database/upgrades/2022-02-18-phone-ping-ts.go deleted file mode 100644 index d5061bc..0000000 --- a/database/upgrades/2022-02-18-phone-ping-ts.go +++ /dev/null @@ -1,10 +0,0 @@ -package upgrades - -import "database/sql" - -func init() { - upgrades[37] = upgrade{"Store timestamp for previous phone ping", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN phone_last_pinged BIGINT`) - return err - }} -} diff --git a/database/upgrades/2022-03-05-reactions.go b/database/upgrades/2022-03-05-reactions.go deleted file mode 100644 index 2e13e04..0000000 --- a/database/upgrades/2022-03-05-reactions.go +++ /dev/null @@ -1,39 +0,0 @@ -package upgrades - -import "database/sql" - -func init() { - upgrades[38] = upgrade{"Add support for reactions", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE message ADD COLUMN type TEXT NOT NULL DEFAULT 'message'`) - if err != nil { - return err - } - if ctx.dialect == Postgres { - _, err = tx.Exec("ALTER TABLE message ALTER COLUMN type DROP DEFAULT") - if err != nil { - return err - } - } - _, err = tx.Exec("UPDATE message SET type='' WHERE error='decryption_failed'") - if err != nil { - return err - } - _, err = tx.Exec("UPDATE message SET type='fake' WHERE jid LIKE 'FAKE::%' OR mxid LIKE 'net.maunium.whatsapp.fake::%' OR jid=mxid") - if err != nil { - return err - } - _, err = tx.Exec(`CREATE TABLE reaction ( - chat_jid TEXT, - chat_receiver TEXT, - target_jid TEXT, - sender TEXT, - mxid TEXT NOT NULL, - jid TEXT NOT NULL, - PRIMARY KEY (chat_jid, chat_receiver, target_jid, sender), - CONSTRAINT target_message_fkey FOREIGN KEY (chat_jid, chat_receiver, target_jid) - REFERENCES message(chat_jid, chat_receiver, jid) - ON DELETE CASCADE ON UPDATE CASCADE - )`) - return err - }} -} diff --git a/database/upgrades/2022-03-15-prioritized-backfill.go b/database/upgrades/2022-03-15-prioritized-backfill.go deleted file mode 100644 index c7a3892..0000000 --- a/database/upgrades/2022-03-15-prioritized-backfill.go +++ /dev/null @@ -1,45 +0,0 @@ -package upgrades - -import ( - "database/sql" - "fmt" -) - -func init() { - upgrades[39] = upgrade{"Add backfill queue", func(tx *sql.Tx, ctx context) error { - // The queue_id needs to auto-increment every insertion. For SQLite, - // INTEGER PRIMARY KEY is an alias for the ROWID, so it will - // auto-increment. See https://sqlite.org/lang_createtable.html#rowid - // For Postgres, we need to add GENERATED ALWAYS AS IDENTITY for the - // same functionality. - queueIDColumnTypeModifier := "" - if ctx.dialect == Postgres { - queueIDColumnTypeModifier = "GENERATED ALWAYS AS IDENTITY" - } - - _, err := tx.Exec(fmt.Sprintf(` - CREATE TABLE backfill_queue ( - queue_id INTEGER PRIMARY KEY %s, - user_mxid TEXT, - type INTEGER NOT NULL, - priority INTEGER NOT NULL, - portal_jid TEXT, - portal_receiver TEXT, - time_start TIMESTAMP, - time_end TIMESTAMP, - max_batch_events INTEGER NOT NULL, - max_total_events INTEGER, - batch_delay INTEGER, - completed_at TIMESTAMP, - - FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, - FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE - ) - `, queueIDColumnTypeModifier)) - if err != nil { - return err - } - - return err - }} -} diff --git a/database/upgrades/2022-03-18-historysync-store.go b/database/upgrades/2022-03-18-historysync-store.go deleted file mode 100644 index 3625069..0000000 --- a/database/upgrades/2022-03-18-historysync-store.go +++ /dev/null @@ -1,52 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[40] = upgrade{"Store history syncs for later backfills", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(` - CREATE TABLE history_sync_conversation ( - user_mxid TEXT, - conversation_id TEXT, - portal_jid TEXT, - portal_receiver TEXT, - last_message_timestamp TIMESTAMP, - archived BOOLEAN, - pinned INTEGER, - mute_end_time TIMESTAMP, - disappearing_mode INTEGER, - end_of_history_transfer_type INTEGER, - ephemeral_expiration INTEGER, - marked_as_unread BOOLEAN, - unread_count INTEGER, - - PRIMARY KEY (user_mxid, conversation_id), - FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, - FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE - ) - `) - if err != nil { - return err - } - _, err = tx.Exec(` - CREATE TABLE history_sync_message ( - user_mxid TEXT, - conversation_id TEXT, - message_id TEXT, - timestamp TIMESTAMP, - data BYTEA, - - PRIMARY KEY (user_mxid, conversation_id, message_id), - FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, - FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE - ) - `) - if err != nil { - return err - } - - return nil - }} -} diff --git a/database/upgrades/2022-04-29-backfillqueue-type-order.go b/database/upgrades/2022-04-29-backfillqueue-type-order.go deleted file mode 100644 index dfb4e8e..0000000 --- a/database/upgrades/2022-04-29-backfillqueue-type-order.go +++ /dev/null @@ -1,20 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[41] = upgrade{"Update backfill queue tables to be sortable by priority", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(` - UPDATE backfill_queue - SET type=CASE - WHEN type=1 THEN 200 - WHEN type=2 THEN 300 - ELSE type - END - WHERE type=1 OR type=2 - `) - return err - }} -} diff --git a/database/upgrades/2022-05-09-media-backfill-requests-queue-table.go b/database/upgrades/2022-05-09-media-backfill-requests-queue-table.go deleted file mode 100644 index 2470ffa..0000000 --- a/database/upgrades/2022-05-09-media-backfill-requests-queue-table.go +++ /dev/null @@ -1,26 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[42] = upgrade{"Add table of media to request from the user's phone", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(` - CREATE TABLE media_backfill_requests ( - user_mxid TEXT, - portal_jid TEXT, - portal_receiver TEXT, - event_id TEXT, - media_key BYTEA, - status INTEGER, - error TEXT, - - PRIMARY KEY (user_mxid, portal_jid, portal_receiver, event_id), - FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, - FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE - ) - `) - return err - }} -} diff --git a/database/upgrades/2022-05-11-add-user-timezone.go b/database/upgrades/2022-05-11-add-user-timezone.go deleted file mode 100644 index 4420cb2..0000000 --- a/database/upgrades/2022-05-11-add-user-timezone.go +++ /dev/null @@ -1,12 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[43] = upgrade{"Add timezone column to user table", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN timezone TEXT`) - return err - }} -} diff --git a/database/upgrades/2022-05-12-backfillqueue-dispatch-time.go b/database/upgrades/2022-05-12-backfillqueue-dispatch-time.go deleted file mode 100644 index 52530e6..0000000 --- a/database/upgrades/2022-05-12-backfillqueue-dispatch-time.go +++ /dev/null @@ -1,34 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[44] = upgrade{"Add dispatch time to backfill queue", func(tx *sql.Tx, ctx context) error { - // First, add dispatch_time TIMESTAMP column - _, err := tx.Exec(` - ALTER TABLE backfill_queue - ADD COLUMN dispatch_time TIMESTAMP - `) - if err != nil { - return err - } - - // For all previous jobs, set dispatch time to the completed time. - _, err = tx.Exec(` - UPDATE backfill_queue - SET dispatch_time=completed_at - `) - if err != nil { - return err - } - - // Remove time_end from the backfill queue - _, err = tx.Exec(` - ALTER TABLE backfill_queue - DROP COLUMN time_end - `) - return err - }} -} diff --git a/database/upgrades/2022-05-12-history-sync-message-add-added-timestamp.go b/database/upgrades/2022-05-12-history-sync-message-add-added-timestamp.go deleted file mode 100644 index 0c7b325..0000000 --- a/database/upgrades/2022-05-12-history-sync-message-add-added-timestamp.go +++ /dev/null @@ -1,16 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[45] = upgrade{"Add inserted time to history sync message", func(tx *sql.Tx, ctx context) error { - // Add the inserted time TIMESTAMP column to history_sync_message - _, err := tx.Exec(` - ALTER TABLE history_sync_message - ADD COLUMN inserted_time TIMESTAMP - `) - return err - }} -} diff --git a/database/upgrades/2022-05-16-room-backfill-state.go b/database/upgrades/2022-05-16-room-backfill-state.go deleted file mode 100644 index 976b129..0000000 --- a/database/upgrades/2022-05-16-room-backfill-state.go +++ /dev/null @@ -1,25 +0,0 @@ -package upgrades - -import ( - "database/sql" -) - -func init() { - upgrades[46] = upgrade{"Create the backfill state table", func(tx *sql.Tx, ctx context) error { - _, err := tx.Exec(` - CREATE TABLE backfill_state ( - user_mxid TEXT, - portal_jid TEXT, - portal_receiver TEXT, - processing_batch BOOLEAN, - backfill_complete BOOLEAN, - first_expected_ts INTEGER, - - PRIMARY KEY (user_mxid, portal_jid, portal_receiver), - FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE, - FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE - ) - `) - return err - }} -} diff --git a/database/upgrades/45-backfillqueue-dispatch-time.sql b/database/upgrades/45-backfillqueue-dispatch-time.sql new file mode 100644 index 0000000..1c02b12 --- /dev/null +++ b/database/upgrades/45-backfillqueue-dispatch-time.sql @@ -0,0 +1,5 @@ +-- v45: Add dispatch time to backfill queue + +ALTER TABLE backfill_queue ADD COLUMN dispatch_time TIMESTAMP; +UPDATE backfill_queue SET dispatch_time=completed_at; +ALTER TABLE backfill_queue DROP COLUMN time_end; diff --git a/database/upgrades/46-history-sync-message-added-timestamp.sql b/database/upgrades/46-history-sync-message-added-timestamp.sql new file mode 100644 index 0000000..4cf2cfc --- /dev/null +++ b/database/upgrades/46-history-sync-message-added-timestamp.sql @@ -0,0 +1,3 @@ +-- v46: Add inserted time to history sync message + +ALTER TABLE history_sync_message ADD COLUMN inserted_time TIMESTAMP; diff --git a/database/upgrades/47-room-backfill-state.sql b/database/upgrades/47-room-backfill-state.sql new file mode 100644 index 0000000..40e052a --- /dev/null +++ b/database/upgrades/47-room-backfill-state.sql @@ -0,0 +1,13 @@ +-- v47: Add table for keeping track of backfill state + +CREATE TABLE backfill_state ( + user_mxid TEXT, + portal_jid TEXT, + portal_receiver TEXT, + processing_batch BOOLEAN, + backfill_complete BOOLEAN, + first_expected_ts TIMESTAMP, + PRIMARY KEY (user_mxid, portal_jid, portal_receiver), + FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal (jid, receiver) ON DELETE CASCADE +); diff --git a/database/upgrades/48-crypto-store-handling-split.sql b/database/upgrades/48-crypto-store-handling-split.sql new file mode 100644 index 0000000..2ef9da1 --- /dev/null +++ b/database/upgrades/48-crypto-store-handling-split.sql @@ -0,0 +1,7 @@ +-- v48: Move crypto/state/whatsmeow store upgrade handling to separate systems +CREATE TABLE crypto_version (version INTEGER PRIMARY KEY); +INSERT INTO crypto_version VALUES (6); +CREATE TABLE whatsmeow_version (version INTEGER PRIMARY KEY); +INSERT INTO whatsmeow_version VALUES (1); +CREATE TABLE mx_version (version INTEGER PRIMARY KEY); +INSERT INTO mx_version VALUES (1); diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go index eba9df8..b60b00a 100644 --- a/database/upgrades/upgrades.go +++ b/database/upgrades/upgrades.go @@ -1,180 +1,27 @@ +// Copyright (c) 2022 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package upgrades import ( "database/sql" + "embed" "errors" - "fmt" - "strings" - log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/util/dbutil" ) -type Dialect int +var Table dbutil.UpgradeTable -const ( - Postgres Dialect = iota - SQLite -) +//go:embed *.sql +var rawUpgrades embed.FS -func (dialect Dialect) String() string { - switch dialect { - case Postgres: - return "postgres" - case SQLite: - return "sqlite3" - default: - return "" - } -} - -type upgradeFunc func(*sql.Tx, context) error - -type context struct { - dialect Dialect - db *sql.DB - log log.Logger -} - -type upgrade struct { - message string - fn upgradeFunc -} - -const NumberOfUpgrades = 47 - -var upgrades [NumberOfUpgrades]upgrade - -var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version") -var ErrForeignTables = fmt.Errorf("the database contains foreign tables") -var ErrNotOwned = fmt.Errorf("the database is owned by") -var IgnoreForeignTables = false - -const databaseOwner = "mautrix-whatsapp" - -func GetVersion(db *sql.DB) (int, error) { - _, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)") - if err != nil { - return -1, err - } - - version := 0 - err = db.QueryRow("SELECT version FROM version LIMIT 1").Scan(&version) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return -1, err - } - return version, nil -} - -const tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)" -const tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND table_name=$1)" - -func tableExists(dialect Dialect, db *sql.DB, table string) (exists bool) { - if dialect == SQLite { - _ = db.QueryRow(tableExistsSQLite, table).Scan(&exists) - } else if dialect == Postgres { - _ = db.QueryRow(tableExistsPostgres, table).Scan(&exists) - } - return -} - -const createOwnerTable = ` -CREATE TABLE IF NOT EXISTS database_owner ( - key INTEGER PRIMARY KEY DEFAULT 0, - owner TEXT NOT NULL -) -` - -func CheckDatabaseOwner(dialect Dialect, db *sql.DB) error { - var owner string - if !IgnoreForeignTables { - if tableExists(dialect, db, "state_groups_state") { - return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables) - } else if tableExists(dialect, db, "goose_db_version") { - return fmt.Errorf("%w (found goose_db_version, possibly belonging to Dendrite)", ErrForeignTables) - } - } - if _, err := db.Exec(createOwnerTable); err != nil { - return fmt.Errorf("failed to ensure database owner table exists: %w", err) - } else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) { - _, err = db.Exec("INSERT INTO database_owner (owner) VALUES ($1)", databaseOwner) - if err != nil { - return fmt.Errorf("failed to insert database owner: %w", err) - } - } else if err != nil { - return fmt.Errorf("failed to check database owner: %w", err) - } else if owner != databaseOwner { - return fmt.Errorf("%w %s", ErrNotOwned, owner) - } - return nil -} - -func SetVersion(tx *sql.Tx, version int) error { - _, err := tx.Exec("DELETE FROM version") - if err != nil { - return err - } - _, err = tx.Exec("INSERT INTO version (version) VALUES ($1)", version) - return err -} - -func execMany(tx *sql.Tx, queries ...string) error { - for _, query := range queries { - _, err := tx.Exec(query) - if err != nil { - return err - } - } - return nil -} - -func Run(log log.Logger, dialectName string, db *sql.DB) error { - var dialect Dialect - switch strings.ToLower(dialectName) { - case "postgres": - dialect = Postgres - case "sqlite3": - dialect = SQLite - default: - return fmt.Errorf("unknown dialect %s", dialectName) - } - - err := CheckDatabaseOwner(dialect, db) - if err != nil { - return err - } - - version, err := GetVersion(db) - if err != nil { - return err - } - - if version > NumberOfUpgrades { - return fmt.Errorf("%w: currently on v%d, latest known: v%d", ErrUnsupportedDatabaseVersion, version, NumberOfUpgrades) - } - - log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades) - for i, upgradeItem := range upgrades[version:] { - if upgradeItem.fn == nil { - continue - } - log.Infofln("Upgrading database to v%d: %s", version+i+1, upgradeItem.message) - var tx *sql.Tx - tx, err = db.Begin() - if err != nil { - return err - } - err = upgradeItem.fn(tx, context{dialect, db, log}) - if err != nil { - return err - } - err = SetVersion(tx, version+i+1) - if err != nil { - return err - } - err = tx.Commit() - if err != nil { - return err - } - } - return nil +func init() { + Table.Register(-1, 43, "Unsupported version", func(tx *sql.Tx, database *dbutil.Database) error { + return errors.New("please upgrade to mautrix-whatsapp v0.4.0 before upgrading to a newer version") + }) + Table.RegisterFS(rawUpgrades) } diff --git a/database/user.go b/database/user.go index 8a73dc7..8b77850 100644 --- a/database/user.go +++ b/database/user.go @@ -24,6 +24,7 @@ import ( log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/util/dbutil" "go.mau.fi/whatsmeow/types" ) @@ -89,7 +90,7 @@ type User struct { inSpaceCacheLock sync.Mutex } -func (user *User) Scan(row Scannable) *User { +func (user *User) Scan(row dbutil.Scannable) *User { var username, timezone sql.NullString var device, agent sql.NullByte var phoneLastSeen, phoneLastPinged sql.NullInt64 diff --git a/disappear.go b/disappear.go index 77eeb5b..34736a3 100644 --- a/disappear.go +++ b/disappear.go @@ -50,9 +50,9 @@ func (portal *Portal) ScheduleDisappearing() { } } -func (bridge *Bridge) SleepAndDeleteUpcoming() { - for _, msg := range bridge.DB.DisappearingMessage.GetUpcomingScheduled(1 * time.Hour) { - portal := bridge.GetPortalByMXID(msg.RoomID) +func (br *WABridge) SleepAndDeleteUpcoming() { + for _, msg := range br.DB.DisappearingMessage.GetUpcomingScheduled(1 * time.Hour) { + portal := br.GetPortalByMXID(msg.RoomID) if portal == nil { msg.Delete() } else { diff --git a/example-config.yaml b/example-config.yaml index f8a70d8..c460893 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -43,14 +43,6 @@ appservice: max_conn_idle_time: null max_conn_lifetime: null - # Settings for provisioning API - provisioning: - # Prefix for the provisioning API paths. - prefix: /_matrix/provision - # Shared secret for authentication. If set to "generate", a random secret will be generated, - # or if set to "disable", the provisioning API will be disabled. - shared_secret: generate - # The unique ID of this appservice. id: whatsapp # Appservice bot details. @@ -317,6 +309,14 @@ bridge: # Verification by the bridge is not yet implemented. require_verification: true + # Settings for provisioning API + provisioning: + # Prefix for the provisioning API paths. + prefix: /_matrix/provision + # Shared secret for authentication. If set to "generate", a random secret will be generated, + # or if set to "disable", the provisioning API will be disabled. + shared_secret: generate + # Permissions for using the bridge. # Permitted values: # relay - Talk through the relaybot (if enabled), no access otherwise diff --git a/formatting.go b/formatting.go index 9ff9144..e99d186 100644 --- a/formatting.go +++ b/formatting.go @@ -37,7 +37,7 @@ var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```") const mentionedJIDsContextKey = "net.maunium.whatsapp.mentioned_jids" type Formatter struct { - bridge *Bridge + bridge *WABridge matrixHTMLParser *format.HTMLParser @@ -46,7 +46,7 @@ type Formatter struct { waReplFuncText map[*regexp.Regexp]func(string) string } -func NewFormatter(bridge *Bridge) *Formatter { +func NewFormatter(bridge *WABridge) *Formatter { formatter := &Formatter{ bridge: bridge, matrixHTMLParser: &format.HTMLParser{ diff --git a/go.mod b/go.mod index 49e07ca..ccdac5b 100644 --- a/go.mod +++ b/go.mod @@ -14,10 +14,8 @@ require ( golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9 golang.org/x/net v0.0.0-20220513224357-95641704303c google.golang.org/protobuf v1.28.0 - gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99 - maunium.net/go/mauflag v1.0.0 maunium.net/go/maulogger/v2 v2.3.2 - maunium.net/go/mautrix v0.11.1-0.20220518174602-87d2cd49a4d1 + maunium.net/go/mautrix v0.11.1-0.20220521215033-d578d1a610d5 ) require ( @@ -37,7 +35,8 @@ require ( golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 // indirect golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886 // indirect golang.org/x/text v0.3.7 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.0 // indirect + maunium.net/go/mauflag v1.0.0 // indirect ) // Exclude some things that cause go.sum to explode diff --git a/go.sum b/go.sum index 5335dbf..54db25f 100644 --- a/go.sum +++ b/go.sum @@ -99,14 +99,13 @@ google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99 h1:dbuHpmKjkDzSOMKAWl10QNlgaZUd3V1q99xc81tt2Kc= -gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0 h1:hjy8E9ON/egN1tAYqKb61G10WtihqetD4sz2H+8nIeA= +gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0= maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A= -maunium.net/go/mautrix v0.11.1-0.20220518174602-87d2cd49a4d1 h1:+KEF+nSuBfHWsfQRz92YP/DdSLbComLoXCXgcrH6WRU= -maunium.net/go/mautrix v0.11.1-0.20220518174602-87d2cd49a4d1/go.mod h1:K29EcHwsNg6r7fMfwvi0GHQ9o5wSjqB9+Q8RjCIQEjA= +maunium.net/go/mautrix v0.11.1-0.20220521215033-d578d1a610d5 h1:7ZORg2h+lflc1HwjTKCXZnykauXD+wzbW+VDknbv6SU= +maunium.net/go/mautrix v0.11.1-0.20220521215033-d578d1a610d5/go.mod h1:oma8o6Y/5jcViBlDbX7tp1ajP2XP+b78h8twdI+zKI0= diff --git a/main.go b/main.go index d83342d..671ddaf 100644 --- a/main.go +++ b/main.go @@ -18,43 +18,26 @@ package main import ( _ "embed" - "errors" - "fmt" "net/http" "os" - "os/signal" "strconv" "strings" "sync" - "syscall" "time" - "google.golang.org/protobuf/proto" - "go.mau.fi/whatsmeow" waProto "go.mau.fi/whatsmeow/binary/proto" "go.mau.fi/whatsmeow/store" "go.mau.fi/whatsmeow/store/sqlstore" "go.mau.fi/whatsmeow/types" + "google.golang.org/protobuf/proto" - flag "maunium.net/go/mauflag" - log "maunium.net/go/maulogger/v2" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/util/configupgrade" "maunium.net/go/mautrix-whatsapp/config" "maunium.net/go/mautrix-whatsapp/database" - "maunium.net/go/mautrix-whatsapp/database/upgrades" -) - -// The name and repo URL of the bridge. -var ( - Name = "mautrix-whatsapp" - URL = "https://github.com/mautrix/whatsapp" ) // Information to find out exactly which commit the bridge was built from. @@ -65,120 +48,19 @@ var ( BuildTime = "unknown" ) -var ( - // Version is the version number of the bridge. Changed manually when making a release. - Version = "0.4.0" - // WAVersion is the version number exposed to WhatsApp. Filled in init() - WAVersion = "" - // VersionString is the bridge version, plus commit information. Filled in init() using the build-time values. - VersionString = "" -) - //go:embed example-config.yaml var ExampleConfig string -func init() { - if len(Tag) > 0 && Tag[0] == 'v' { - Tag = Tag[1:] - } - if Tag != Version { - suffix := "" - if !strings.HasSuffix(Version, "+dev") { - suffix = "+dev" - } - if len(Commit) > 8 { - Version = fmt.Sprintf("%s%s.%s", Version, suffix, Commit[:8]) - } else { - Version = fmt.Sprintf("%s%s.unknown", Version, suffix) - } - } - mautrix.DefaultUserAgent = fmt.Sprintf("mautrix-whatsapp/%s %s", Version, mautrix.DefaultUserAgent) - WAVersion = strings.FieldsFunc(Version, func(r rune) bool { return r == '-' || r == '+' })[0] - VersionString = fmt.Sprintf("%s %s (%s)", Name, Version, BuildTime) - - config.ExampleConfig = ExampleConfig -} - -var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String() -var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config to disk.", "false").Bool() -var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String() -var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool() -var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool() -var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool() -var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool() -var migrateFrom = flag.Make().LongKey("migrate-db").Usage("Source database type and URI to migrate from.").Bool() -var wantHelp, _ = flag.MakeHelpFlag() - -func (bridge *Bridge) GenerateRegistration() { - if *dontSaveConfig { - // We need to save the generated as_token and hs_token in the config - _, _ = fmt.Fprintln(os.Stderr, "--no-update is not compatible with --generate-registration") - os.Exit(5) - } - reg, err := bridge.Config.NewRegistration() - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to generate registration:", err) - os.Exit(20) - } - - err = reg.Save(*registrationPath) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to save registration:", err) - os.Exit(21) - } - - err = config.Mutate(*configPath, func(helper *configupgrade.Helper) { - helper.Set(configupgrade.Str, bridge.Config.AppService.ASToken, "appservice", "as_token") - helper.Set(configupgrade.Str, bridge.Config.AppService.HSToken, "appservice", "hs_token") - }) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to save config:", err) - os.Exit(22) - } - fmt.Println("Registration generated. Add the path to the registration to your Synapse config, restart it, then start the bridge.") - os.Exit(0) -} - -func (bridge *Bridge) MigrateDatabase() { - oldDB, err := database.New(config.DatabaseConfig{Type: flag.Arg(0), URI: flag.Arg(1)}, log.DefaultLogger) - if err != nil { - fmt.Println("Failed to open old database:", err) - os.Exit(30) - } - err = oldDB.Init() - if err != nil { - fmt.Println("Failed to upgrade old database:", err) - os.Exit(31) - } - - newDB, err := database.New(bridge.Config.AppService.Database, log.DefaultLogger) - if err != nil { - fmt.Println("Failed to open new database:", err) - os.Exit(32) - } - err = newDB.Init() - if err != nil { - fmt.Println("Failed to upgrade new database:", err) - os.Exit(33) - } - - database.Migrate(oldDB, newDB) -} - -type Bridge struct { - AS *appservice.AppService - EventProcessor *appservice.EventProcessor - MatrixHandler *MatrixHandler - Config *config.Config - DB *database.Database - Log log.Logger - StateStore *database.SQLStateStore - Provisioning *ProvisioningAPI - Bot *appservice.IntentAPI - Formatter *Formatter - Crypto Crypto - Metrics *MetricsHandler - WAContainer *sqlstore.Container +type WABridge struct { + bridge.Bridge + MatrixHandler *MatrixHandler + Config *config.Config + DB *database.Database + Provisioning *ProvisioningAPI + Formatter *Formatter + Metrics *MetricsHandler + WAContainer *sqlstore.Container + WAVersion string usersByMXID map[id.UserID]*User usersByUsername map[string]*User @@ -195,111 +77,32 @@ type Bridge struct { puppetsLock sync.Mutex } -type Crypto interface { - HandleMemberEvent(*event.Event) - Decrypt(*event.Event) (*event.Event, error) - Encrypt(id.RoomID, event.Type, event.Content) (*event.EncryptedEventContent, error) - WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool - RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) - ResetSession(id.RoomID) - Init() error - Start() - Stop() -} - -func (bridge *Bridge) ensureConnection() { - for { - versions, err := bridge.Bot.Versions() - if err != nil { - bridge.Log.Errorfln("Failed to connect to homeserver: %v. Retrying in 10 seconds...", err) - time.Sleep(10 * time.Second) - continue - } - if !versions.ContainsGreaterOrEqual(mautrix.SpecV11) { - bridge.Log.Warnfln("Server isn't advertising modern spec versions") - } - resp, err := bridge.Bot.Whoami() - if err != nil { - if errors.Is(err, mautrix.MUnknownToken) { - bridge.Log.Fatalln("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") - os.Exit(16) - } else if errors.Is(err, mautrix.MExclusive) { - bridge.Log.Fatalln("The as_token was accepted, but the /register request was not. Are the homeserver domain and username template in the config correct, and do they match the values in the registration?") - os.Exit(16) - } - bridge.Log.Errorfln("Failed to connect to homeserver: %v. Retrying in 10 seconds...", err) - time.Sleep(10 * time.Second) - } else if resp.UserID != bridge.Bot.UserID { - bridge.Log.Fatalln("Unexpected user ID in whoami call: got %s, expected %s", resp.UserID, bridge.Bot.UserID) - os.Exit(17) - } else { - break - } - } -} - -func (bridge *Bridge) Init() { - var err error - - bridge.AS, err = bridge.Config.MakeAppService() - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to initialize AppService:", err) - os.Exit(11) - } - _, _ = bridge.AS.Init() - - bridge.Log = log.Create() - bridge.Config.Logging.Configure(bridge.Log) - log.DefaultLogger = bridge.Log.(*log.BasicLogger) - if len(bridge.Config.Logging.FileNameFormat) > 0 { - err = log.OpenFile() - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to open log file:", err) - os.Exit(12) - } - } - bridge.AS.Log = log.Sub("Matrix") - bridge.Bot = bridge.AS.BotIntent() - bridge.Log.Infoln("Initializing", VersionString) - - bridge.Log.Debugln("Initializing database connection") - bridge.DB, err = database.New(bridge.Config.AppService.Database, bridge.Log) - if err != nil { - bridge.Log.Fatalln("Failed to initialize database connection:", err) - os.Exit(14) - } - - bridge.Log.Debugln("Initializing state store") - bridge.StateStore = database.NewSQLStateStore(bridge.DB) - bridge.AS.StateStore = bridge.StateStore - - Segment.log = bridge.Log.Sub("Segment") - Segment.key = bridge.Config.SegmentKey +func (br *WABridge) Init() { + Segment.log = br.Log.Sub("Segment") + Segment.key = br.Config.SegmentKey if Segment.IsEnabled() { Segment.log.Infoln("Segment metrics are enabled") } - bridge.WAContainer = sqlstore.NewWithDB(bridge.DB.DB, bridge.Config.AppService.Database.Type, nil) - bridge.WAContainer.DatabaseErrorHandler = bridge.DB.HandleSignalStoreError + br.DB = database.New(br.Bridge.DB) + br.WAContainer = sqlstore.NewWithDB(br.DB.DB, br.DB.Dialect.String(), nil) + br.WAContainer.DatabaseErrorHandler = br.DB.HandleSignalStoreError - ss := bridge.Config.AppService.Provisioning.SharedSecret + ss := br.Config.Bridge.Provisioning.SharedSecret if len(ss) > 0 && ss != "disable" { - bridge.Provisioning = &ProvisioningAPI{bridge: bridge} + br.Provisioning = &ProvisioningAPI{bridge: br} } - bridge.Log.Debugln("Initializing Matrix event processor") - bridge.EventProcessor = appservice.NewEventProcessor(bridge.AS) - bridge.Log.Debugln("Initializing Matrix event handler") - bridge.MatrixHandler = NewMatrixHandler(bridge) - bridge.Formatter = NewFormatter(bridge) - bridge.Crypto = NewCryptoHelper(bridge) - bridge.Metrics = NewMetricsHandler(bridge.Config.Metrics.Listen, bridge.Log.Sub("Metrics"), bridge.DB) + br.Log.Debugln("Initializing Matrix event handler") + br.MatrixHandler = NewMatrixHandler(br) + br.Formatter = NewFormatter(br) + br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.Log.Sub("Metrics"), br.DB) - store.BaseClientPayload.UserAgent.OsVersion = proto.String(WAVersion) - store.BaseClientPayload.UserAgent.OsBuildNumber = proto.String(WAVersion) - store.CompanionProps.Os = proto.String(bridge.Config.WhatsApp.OSName) - store.CompanionProps.RequireFullSync = proto.Bool(bridge.Config.Bridge.HistorySync.RequestFullSync) - versionParts := strings.Split(WAVersion, ".") + store.BaseClientPayload.UserAgent.OsVersion = proto.String(br.WAVersion) + store.BaseClientPayload.UserAgent.OsBuildNumber = proto.String(br.WAVersion) + store.CompanionProps.Os = proto.String(br.Config.WhatsApp.OSName) + store.CompanionProps.RequireFullSync = proto.Bool(br.Config.Bridge.HistorySync.RequestFullSync) + versionParts := strings.Split(br.WAVersion, ".") if len(versionParts) > 2 { primary, _ := strconv.Atoi(versionParts[0]) secondary, _ := strconv.Atoi(versionParts[1]) @@ -308,161 +111,107 @@ func (bridge *Bridge) Init() { store.CompanionProps.Version.Secondary = proto.Uint32(uint32(secondary)) store.CompanionProps.Version.Tertiary = proto.Uint32(uint32(tertiary)) } - platformID, ok := waProto.CompanionProps_CompanionPropsPlatformType_value[strings.ToUpper(bridge.Config.WhatsApp.BrowserName)] + platformID, ok := waProto.CompanionProps_CompanionPropsPlatformType_value[strings.ToUpper(br.Config.WhatsApp.BrowserName)] if ok { store.CompanionProps.PlatformType = waProto.CompanionProps_CompanionPropsPlatformType(platformID).Enum() } } -func (bridge *Bridge) Start() { - bridge.Log.Debugln("Running database upgrades") - err := bridge.DB.Init() - if err != nil && (!errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) || !*ignoreUnsupportedDatabase) { - bridge.Log.Fatalln("Failed to initialize database:", err) - if errors.Is(err, upgrades.ErrForeignTables) { - bridge.Log.Infoln("You can use --ignore-foreign-tables to ignore this error") - } else if errors.Is(err, upgrades.ErrNotOwned) { - bridge.Log.Infoln("Sharing the same database with different programs is not supported") - } else if errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) { - bridge.Log.Infoln("Downgrading the bridge is not supported") - } +func (br *WABridge) Start() { + err := br.WAContainer.Upgrade() + if err != nil { + br.Log.Fatalln("Failed to upgrade whatsmeow database: %v", err) os.Exit(15) } - bridge.Log.Debugln("Checking connection to homeserver") - bridge.ensureConnection() - if bridge.Crypto != nil { - err = bridge.Crypto.Init() - if err != nil { - bridge.Log.Fatalln("Error initializing end-to-bridge encryption:", err) - os.Exit(19) - } + if br.Provisioning != nil { + br.Log.Debugln("Initializing provisioning API") + br.Provisioning.Init() } - if bridge.Provisioning != nil { - bridge.Log.Debugln("Initializing provisioning API") - bridge.Provisioning.Init() - } - bridge.Log.Debugln("Starting application service HTTP server") - go bridge.AS.Start() - bridge.Log.Debugln("Starting event processor") - go bridge.EventProcessor.Start() - go bridge.CheckWhatsAppUpdate() - go bridge.UpdateBotProfile() - if bridge.Crypto != nil { - go bridge.Crypto.Start() - } - go bridge.StartUsers() - if bridge.Config.Metrics.Enabled { - go bridge.Metrics.Start() + go br.CheckWhatsAppUpdate() + go br.StartUsers() + if br.Config.Metrics.Enabled { + go br.Metrics.Start() } - if bridge.Config.Bridge.ResendBridgeInfo { - go bridge.ResendBridgeInfo() + if br.Config.Bridge.ResendBridgeInfo { + go br.ResendBridgeInfo() } - go bridge.Loop() - bridge.AS.Ready = true + go br.Loop() } -func (bridge *Bridge) CheckWhatsAppUpdate() { - bridge.Log.Debugfln("Checking for WhatsApp web update") +func (br *WABridge) CheckWhatsAppUpdate() { + br.Log.Debugfln("Checking for WhatsApp web update") resp, err := whatsmeow.CheckUpdate(http.DefaultClient) if err != nil { - bridge.Log.Warnfln("Failed to check for WhatsApp web update: %v", err) + br.Log.Warnfln("Failed to check for WhatsApp web update: %v", err) return } if store.GetWAVersion() == resp.ParsedVersion { - bridge.Log.Debugfln("Bridge is using latest WhatsApp web protocol") + br.Log.Debugfln("Bridge is using latest WhatsApp web protocol") } else if store.GetWAVersion().LessThan(resp.ParsedVersion) { if resp.IsBelowHard || resp.IsBroken { - bridge.Log.Warnfln("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion) + br.Log.Warnfln("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion) } else if resp.IsBelowSoft { - bridge.Log.Infofln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion) + br.Log.Infofln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion) } else { - bridge.Log.Debugfln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion) + br.Log.Debugfln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion) } } else { - bridge.Log.Debugfln("Bridge is using newer than latest WhatsApp web protocol") + br.Log.Debugfln("Bridge is using newer than latest WhatsApp web protocol") } } -func (bridge *Bridge) Loop() { +func (br *WABridge) Loop() { for { - bridge.SleepAndDeleteUpcoming() + br.SleepAndDeleteUpcoming() time.Sleep(1 * time.Hour) - bridge.WarnUsersAboutDisconnection() + br.WarnUsersAboutDisconnection() } } -func (bridge *Bridge) WarnUsersAboutDisconnection() { - bridge.usersLock.Lock() - for _, user := range bridge.usersByUsername { +func (br *WABridge) WarnUsersAboutDisconnection() { + br.usersLock.Lock() + for _, user := range br.usersByUsername { if user.IsConnected() && !user.PhoneRecentlySeen(true) { go user.sendPhoneOfflineWarning() } } - bridge.usersLock.Unlock() + br.usersLock.Unlock() } -func (bridge *Bridge) ResendBridgeInfo() { - if *dontSaveConfig { - bridge.Log.Warnln("Not setting resend_bridge_info to false in config due to --no-update flag") - } else { - err := config.Mutate(*configPath, func(helper *configupgrade.Helper) { - helper.Set(configupgrade.Bool, "false", "bridge", "resend_bridge_info") - }) - if err != nil { - bridge.Log.Errorln("Failed to save config after setting resend_bridge_info to false:", err) - } - } - bridge.Log.Infoln("Re-sending bridge info state event to all portals") - for _, portal := range bridge.GetAllPortals() { - portal.UpdateBridgeInfo() - } - bridge.Log.Infoln("Finished re-sending bridge info state events") +func (br *WABridge) ResendBridgeInfo() { + // FIXME + //if *dontSaveConfig { + // br.Log.Warnln("Not setting resend_bridge_info to false in config due to --no-update flag") + //} else { + // err := config.Mutate(*configPath, func(helper *configupgrade.Helper) { + // helper.Set(configupgrade.Bool, "false", "bridge", "resend_bridge_info") + // }) + // if err != nil { + // br.Log.Errorln("Failed to save config after setting resend_bridge_info to false:", err) + // } + //} + //br.Log.Infoln("Re-sending bridge info state event to all portals") + //for _, portal := range br.GetAllPortals() { + // portal.UpdateBridgeInfo() + //} + //br.Log.Infoln("Finished re-sending bridge info state events") } -func (bridge *Bridge) UpdateBotProfile() { - bridge.Log.Debugln("Updating bot profile") - botConfig := &bridge.Config.AppService.Bot - - var err error - var mxc id.ContentURI - if botConfig.Avatar == "remove" { - err = bridge.Bot.SetAvatarURL(mxc) - } else if len(botConfig.Avatar) > 0 { - mxc, err = id.ParseContentURI(botConfig.Avatar) - if err == nil { - err = bridge.Bot.SetAvatarURL(mxc) - } - botConfig.ParsedAvatar = mxc - } - if err != nil { - bridge.Log.Warnln("Failed to update bot avatar:", err) - } - - if botConfig.Displayname == "remove" { - err = bridge.Bot.SetDisplayName("") - } else if len(botConfig.Displayname) > 0 { - err = bridge.Bot.SetDisplayName(botConfig.Displayname) - } - if err != nil { - bridge.Log.Warnln("Failed to update bot displayname:", err) - } -} - -func (bridge *Bridge) StartUsers() { - bridge.Log.Debugln("Starting users") +func (br *WABridge) StartUsers() { + br.Log.Debugln("Starting users") foundAnySessions := false - for _, user := range bridge.GetAllUsers() { + for _, user := range br.GetAllUsers() { if !user.JID.IsEmpty() { foundAnySessions = true } go user.Connect() } if !foundAnySessions { - bridge.sendGlobalBridgeState(BridgeState{StateEvent: StateUnconfigured}.fill(nil)) + br.sendGlobalBridgeState(BridgeState{StateEvent: StateUnconfigured}.fill(nil)) } - bridge.Log.Debugln("Starting custom puppets") - for _, loopuppet := range bridge.GetAllPuppetsWithCustomMXID() { + br.Log.Debugln("Starting custom puppets") + for _, loopuppet := range br.GetAllPuppetsWithCustomMXID() { go func(puppet *Puppet) { puppet.log.Debugln("Starting custom puppet", puppet.CustomMXID) err := puppet.StartCustomMXID(true) @@ -473,80 +222,37 @@ func (bridge *Bridge) StartUsers() { } } -func (bridge *Bridge) Stop() { - if bridge.Crypto != nil { - bridge.Crypto.Stop() +func (br *WABridge) Stop() { + if br.Crypto != nil { + br.Crypto.Stop() } - bridge.AS.Stop() - bridge.Metrics.Stop() - bridge.EventProcessor.Stop() - for _, user := range bridge.usersByUsername { + br.AS.Stop() + br.Metrics.Stop() + br.EventProcessor.Stop() + for _, user := range br.usersByUsername { if user.Client == nil { continue } - bridge.Log.Debugln("Disconnecting", user.MXID) + br.Log.Debugln("Disconnecting", user.MXID) user.Client.Disconnect() close(user.historySyncs) } } -func (bridge *Bridge) Main() { - configData, upgraded, err := config.Upgrade(*configPath, !*dontSaveConfig) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Error updating config:", err) - if configData == nil { - os.Exit(10) - } +func (br *WABridge) GetExampleConfig() string { + return ExampleConfig +} + +func (br *WABridge) GetConfigPtr() interface{} { + br.Config = &config.Config{ + BaseConfig: &br.Bridge.Config, } - - bridge.Config, err = config.Load(configData, upgraded) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to parse config:", err) - os.Exit(10) - } - - if *generateRegistration { - bridge.GenerateRegistration() - return - } else if *migrateFrom { - bridge.MigrateDatabase() - return - } - - bridge.Init() - bridge.Log.Infoln("Bridge initialization complete, starting...") - bridge.Start() - bridge.Log.Infoln("Bridge started!") - - c := make(chan os.Signal) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - <-c - - bridge.Log.Infoln("Interrupt received, stopping...") - bridge.Stop() - bridge.Log.Infoln("Bridge stopped.") - os.Exit(0) + br.Config.BaseConfig.Bridge = &br.Config.Bridge + return br.Config } func main() { - flag.SetHelpTitles( - "mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.", - "mautrix-whatsapp [-h] [-c ] [-r ] [-g] [--migrate-db ]") - err := flag.Parse() - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, err) - flag.PrintHelp() - os.Exit(1) - } else if *wantHelp { - flag.PrintHelp() - os.Exit(0) - } else if *version { - fmt.Println(VersionString) - return - } - upgrades.IgnoreForeignTables = *ignoreForeignTables - - (&Bridge{ + br := &WABridge{ usersByMXID: make(map[id.UserID]*User), usersByUsername: make(map[string]*User), spaceRooms: make(map[id.RoomID]*User), @@ -555,5 +261,24 @@ func main() { portalsByJID: make(map[database.PortalKey]*Portal), puppets: make(map[types.JID]*Puppet), puppetsByCustomMXID: make(map[id.UserID]*Puppet), - }).Main() + } + br.Bridge = bridge.Bridge{ + Name: "mautrix-whatsapp", + URL: "https://github.com/mautrix/whatsapp", + Description: "A Matrix-WhatsApp puppeting bridge.", + Version: "0.4.0", + ProtocolName: "WhatsApp", + + ConfigUpgrader: &configupgrade.StructUpgrader{ + SimpleUpgrader: configupgrade.SimpleUpgrader(config.DoUpgrade), + Blocks: config.SpacedBlocks, + Base: ExampleConfig, + }, + + Child: br, + } + br.InitVersion(Tag, Commit, BuildTime) + br.WAVersion = strings.FieldsFunc(br.Version, func(r rune) bool { return r == '-' || r == '+' })[0] + + br.Main() } diff --git a/matrix.go b/matrix.go index aaeb0c4..27f2e12 100644 --- a/matrix.go +++ b/matrix.go @@ -28,6 +28,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/id" @@ -36,13 +37,13 @@ import ( ) type MatrixHandler struct { - bridge *Bridge + bridge *WABridge as *appservice.AppService log maulogger.Logger cmd *CommandHandler } -func NewMatrixHandler(bridge *Bridge) *MatrixHandler { +func NewMatrixHandler(bridge *WABridge) *MatrixHandler { handler := &MatrixHandler{ bridge: bridge, as: bridge.AS, @@ -362,7 +363,7 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) { decrypted, err := mx.bridge.Crypto.Decrypt(evt) decryptionRetryCount := 0 - if errors.Is(err, NoSessionFound) { + if errors.Is(err, bridge.NoSessionFound) { content := evt.Content.AsEncrypted() mx.log.Debugfln("Couldn't find session %s trying to decrypt %s, waiting %d seconds...", content.SessionID, evt.ID, int(sessionWaitTimeout.Seconds())) mx.as.SendErrorMessageSendCheckpoint(evt, appservice.StepDecrypted, err, false, decryptionRetryCount) diff --git a/no-crypto.go b/no-crypto.go deleted file mode 100644 index fa5f434..0000000 --- a/no-crypto.go +++ /dev/null @@ -1,17 +0,0 @@ -//go:build !cgo || nocrypto - -package main - -import ( - "errors" -) - -func NewCryptoHelper(bridge *Bridge) Crypto { - if !bridge.Config.Bridge.Encryption.Allow { - bridge.Log.Warnln("Bridge built without end-to-bridge encryption, but encryption is enabled in config") - } - bridge.Log.Debugln("Bridge built without end-to-bridge encryption") - return nil -} - -var NoSessionFound = errors.New("nil") diff --git a/portal.go b/portal.go index e8afd8d..8593294 100644 --- a/portal.go +++ b/portal.go @@ -45,6 +45,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" @@ -69,68 +70,72 @@ const PrivateChatTopic = "WhatsApp private chat" var ErrStatusBroadcastDisabled = errors.New("status bridging is disabled") -func (bridge *Bridge) GetPortalByMXID(mxid id.RoomID) *Portal { - bridge.portalsLock.Lock() - defer bridge.portalsLock.Unlock() - portal, ok := bridge.portalsByMXID[mxid] +func (br *WABridge) GetPortalByMXID(mxid id.RoomID) *Portal { + br.portalsLock.Lock() + defer br.portalsLock.Unlock() + portal, ok := br.portalsByMXID[mxid] if !ok { - return bridge.loadDBPortal(bridge.DB.Portal.GetByMXID(mxid), nil) + return br.loadDBPortal(br.DB.Portal.GetByMXID(mxid), nil) } return portal } -func (bridge *Bridge) GetPortalByJID(key database.PortalKey) *Portal { - bridge.portalsLock.Lock() - defer bridge.portalsLock.Unlock() - portal, ok := bridge.portalsByJID[key] +func (br *WABridge) GetIPortalByMXID(mxid id.RoomID) bridge.Portal { + return br.GetPortalByMXID(mxid) +} + +func (br *WABridge) GetPortalByJID(key database.PortalKey) *Portal { + br.portalsLock.Lock() + defer br.portalsLock.Unlock() + portal, ok := br.portalsByJID[key] if !ok { - return bridge.loadDBPortal(bridge.DB.Portal.GetByJID(key), &key) + return br.loadDBPortal(br.DB.Portal.GetByJID(key), &key) } return portal } -func (bridge *Bridge) GetAllPortals() []*Portal { - return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAll()) +func (br *WABridge) GetAllPortals() []*Portal { + return br.dbPortalsToPortals(br.DB.Portal.GetAll()) } -func (bridge *Bridge) GetAllPortalsForUser(userID id.UserID) []*Portal { - return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAllForUser(userID)) +func (br *WABridge) GetAllPortalsForUser(userID id.UserID) []*Portal { + return br.dbPortalsToPortals(br.DB.Portal.GetAllForUser(userID)) } -func (bridge *Bridge) GetAllPortalsByJID(jid types.JID) []*Portal { - return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAllByJID(jid)) +func (br *WABridge) GetAllPortalsByJID(jid types.JID) []*Portal { + return br.dbPortalsToPortals(br.DB.Portal.GetAllByJID(jid)) } -func (bridge *Bridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal { - bridge.portalsLock.Lock() - defer bridge.portalsLock.Unlock() +func (br *WABridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal { + br.portalsLock.Lock() + defer br.portalsLock.Unlock() output := make([]*Portal, len(dbPortals)) for index, dbPortal := range dbPortals { if dbPortal == nil { continue } - portal, ok := bridge.portalsByJID[dbPortal.Key] + portal, ok := br.portalsByJID[dbPortal.Key] if !ok { - portal = bridge.loadDBPortal(dbPortal, nil) + portal = br.loadDBPortal(dbPortal, nil) } output[index] = portal } return output } -func (bridge *Bridge) loadDBPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal { +func (br *WABridge) loadDBPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal { if dbPortal == nil { if key == nil { return nil } - dbPortal = bridge.DB.Portal.New() + dbPortal = br.DB.Portal.New() dbPortal.Key = *key dbPortal.Insert() } - portal := bridge.NewPortal(dbPortal) - bridge.portalsByJID[portal.Key] = portal + portal := br.NewPortal(dbPortal) + br.portalsByJID[portal.Key] = portal if len(portal.MXID) > 0 { - bridge.portalsByMXID[portal.MXID] = portal + br.portalsByMXID[portal.MXID] = portal } return portal } @@ -139,14 +144,14 @@ func (portal *Portal) GetUsers() []*User { return nil } -func (bridge *Bridge) newBlankPortal(key database.PortalKey) *Portal { +func (br *WABridge) newBlankPortal(key database.PortalKey) *Portal { portal := &Portal{ - bridge: bridge, - log: bridge.Log.Sub(fmt.Sprintf("Portal/%s", key)), + bridge: br, + log: br.Log.Sub(fmt.Sprintf("Portal/%s", key)), - messages: make(chan PortalMessage, bridge.Config.Bridge.PortalMessageBuffer), - matrixMessages: make(chan PortalMatrixMessage, bridge.Config.Bridge.PortalMessageBuffer), - mediaRetries: make(chan PortalMediaRetry, bridge.Config.Bridge.PortalMessageBuffer), + messages: make(chan PortalMessage, br.Config.Bridge.PortalMessageBuffer), + matrixMessages: make(chan PortalMatrixMessage, br.Config.Bridge.PortalMessageBuffer), + mediaRetries: make(chan PortalMediaRetry, br.Config.Bridge.PortalMessageBuffer), mediaErrorCache: make(map[types.MessageID]*FailedMediaMeta), } @@ -154,15 +159,15 @@ func (bridge *Bridge) newBlankPortal(key database.PortalKey) *Portal { return portal } -func (bridge *Bridge) NewManualPortal(key database.PortalKey) *Portal { - portal := bridge.newBlankPortal(key) - portal.Portal = bridge.DB.Portal.New() +func (br *WABridge) NewManualPortal(key database.PortalKey) *Portal { + portal := br.newBlankPortal(key) + portal.Portal = br.DB.Portal.New() portal.Key = key return portal } -func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal { - portal := bridge.newBlankPortal(dbPortal.Key) +func (br *WABridge) NewPortal(dbPortal *database.Portal) *Portal { + portal := br.newBlankPortal(dbPortal.Key) portal.Portal = dbPortal return portal } @@ -203,7 +208,7 @@ type recentlyHandledWrapper struct { type Portal struct { *database.Portal - bridge *Bridge + bridge *WABridge log log.Logger roomCreateLock sync.Mutex @@ -229,6 +234,10 @@ type Portal struct { relayUser *User } +func (portal *Portal) IsEncrypted() bool { + return portal.Encrypted +} + func (portal *Portal) handleMessageLoopItem(msg PortalMessage) { if len(portal.MXID) == 0 { if msg.fake == nil && msg.undecryptable == nil && (msg.evt == nil || !containsSupportedMessage(msg.evt.Message)) { diff --git a/provisioning.go b/provisioning.go index e532079..16479e9 100644 --- a/provisioning.go +++ b/provisioning.go @@ -43,15 +43,15 @@ import ( ) type ProvisioningAPI struct { - bridge *Bridge + bridge *WABridge log log.Logger } func (prov *ProvisioningAPI) Init() { prov.log = prov.bridge.Log.Sub("Provisioning") - prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.AppService.Provisioning.Prefix) - r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.AppService.Provisioning.Prefix).Subrouter() + prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.Bridge.Provisioning.Prefix) + r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.Bridge.Provisioning.Prefix).Subrouter() r.Use(prov.AuthMiddleware) r.HandleFunc("/v1/ping", prov.Ping).Methods(http.MethodGet) r.HandleFunc("/v1/login", prov.Login).Methods(http.MethodGet) @@ -109,7 +109,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { } else if strings.HasPrefix(auth, "Bearer ") { auth = auth[len("Bearer "):] } - if auth != prov.bridge.Config.AppService.Provisioning.SharedSecret { + if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret { jsonResponse(w, http.StatusForbidden, map[string]interface{}{ "error": "Invalid auth token", "errcode": "M_FORBIDDEN", diff --git a/puppet.go b/puppet.go index bacc39b..1d9fc97 100644 --- a/puppet.go +++ b/puppet.go @@ -39,11 +39,11 @@ import ( var userIDRegex *regexp.Regexp -func (bridge *Bridge) ParsePuppetMXID(mxid id.UserID) (jid types.JID, ok bool) { +func (br *WABridge) ParsePuppetMXID(mxid id.UserID) (jid types.JID, ok bool) { if userIDRegex == nil { userIDRegex = regexp.MustCompile(fmt.Sprintf("^@%s:%s$", - bridge.Config.Bridge.FormatUsername("([0-9]+)"), - bridge.Config.Homeserver.Domain)) + br.Config.Bridge.FormatUsername("([0-9]+)"), + br.Config.Homeserver.Domain)) } match := userIDRegex.FindStringSubmatch(string(mxid)) if len(match) == 2 { @@ -53,79 +53,79 @@ func (bridge *Bridge) ParsePuppetMXID(mxid id.UserID) (jid types.JID, ok bool) { return } -func (bridge *Bridge) GetPuppetByMXID(mxid id.UserID) *Puppet { - jid, ok := bridge.ParsePuppetMXID(mxid) +func (br *WABridge) GetPuppetByMXID(mxid id.UserID) *Puppet { + jid, ok := br.ParsePuppetMXID(mxid) if !ok { return nil } - return bridge.GetPuppetByJID(jid) + return br.GetPuppetByJID(jid) } -func (bridge *Bridge) GetPuppetByJID(jid types.JID) *Puppet { +func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet { jid = jid.ToNonAD() if jid.Server == types.LegacyUserServer { jid.Server = types.DefaultUserServer } else if jid.Server != types.DefaultUserServer { return nil } - bridge.puppetsLock.Lock() - defer bridge.puppetsLock.Unlock() - puppet, ok := bridge.puppets[jid] + br.puppetsLock.Lock() + defer br.puppetsLock.Unlock() + puppet, ok := br.puppets[jid] if !ok { - dbPuppet := bridge.DB.Puppet.Get(jid) + dbPuppet := br.DB.Puppet.Get(jid) if dbPuppet == nil { - dbPuppet = bridge.DB.Puppet.New() + dbPuppet = br.DB.Puppet.New() dbPuppet.JID = jid dbPuppet.Insert() } - puppet = bridge.NewPuppet(dbPuppet) - bridge.puppets[puppet.JID] = puppet + puppet = br.NewPuppet(dbPuppet) + br.puppets[puppet.JID] = puppet if len(puppet.CustomMXID) > 0 { - bridge.puppetsByCustomMXID[puppet.CustomMXID] = puppet + br.puppetsByCustomMXID[puppet.CustomMXID] = puppet } } return puppet } -func (bridge *Bridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet { - bridge.puppetsLock.Lock() - defer bridge.puppetsLock.Unlock() - puppet, ok := bridge.puppetsByCustomMXID[mxid] +func (br *WABridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet { + br.puppetsLock.Lock() + defer br.puppetsLock.Unlock() + puppet, ok := br.puppetsByCustomMXID[mxid] if !ok { - dbPuppet := bridge.DB.Puppet.GetByCustomMXID(mxid) + dbPuppet := br.DB.Puppet.GetByCustomMXID(mxid) if dbPuppet == nil { return nil } - puppet = bridge.NewPuppet(dbPuppet) - bridge.puppets[puppet.JID] = puppet - bridge.puppetsByCustomMXID[puppet.CustomMXID] = puppet + puppet = br.NewPuppet(dbPuppet) + br.puppets[puppet.JID] = puppet + br.puppetsByCustomMXID[puppet.CustomMXID] = puppet } return puppet } -func (bridge *Bridge) GetAllPuppetsWithCustomMXID() []*Puppet { - return bridge.dbPuppetsToPuppets(bridge.DB.Puppet.GetAllWithCustomMXID()) +func (br *WABridge) GetAllPuppetsWithCustomMXID() []*Puppet { + return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID()) } -func (bridge *Bridge) GetAllPuppets() []*Puppet { - return bridge.dbPuppetsToPuppets(bridge.DB.Puppet.GetAll()) +func (br *WABridge) GetAllPuppets() []*Puppet { + return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll()) } -func (bridge *Bridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet { - bridge.puppetsLock.Lock() - defer bridge.puppetsLock.Unlock() +func (br *WABridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet { + br.puppetsLock.Lock() + defer br.puppetsLock.Unlock() output := make([]*Puppet, len(dbPuppets)) for index, dbPuppet := range dbPuppets { if dbPuppet == nil { continue } - puppet, ok := bridge.puppets[dbPuppet.JID] + puppet, ok := br.puppets[dbPuppet.JID] if !ok { - puppet = bridge.NewPuppet(dbPuppet) - bridge.puppets[dbPuppet.JID] = puppet + puppet = br.NewPuppet(dbPuppet) + br.puppets[dbPuppet.JID] = puppet if len(dbPuppet.CustomMXID) > 0 { - bridge.puppetsByCustomMXID[dbPuppet.CustomMXID] = puppet + br.puppetsByCustomMXID[dbPuppet.CustomMXID] = puppet } } output[index] = puppet @@ -133,26 +133,26 @@ func (bridge *Bridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet return output } -func (bridge *Bridge) FormatPuppetMXID(jid types.JID) id.UserID { +func (br *WABridge) FormatPuppetMXID(jid types.JID) id.UserID { return id.NewUserID( - bridge.Config.Bridge.FormatUsername(jid.User), - bridge.Config.Homeserver.Domain) + br.Config.Bridge.FormatUsername(jid.User), + br.Config.Homeserver.Domain) } -func (bridge *Bridge) NewPuppet(dbPuppet *database.Puppet) *Puppet { +func (br *WABridge) NewPuppet(dbPuppet *database.Puppet) *Puppet { return &Puppet{ Puppet: dbPuppet, - bridge: bridge, - log: bridge.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)), + bridge: br, + log: br.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)), - MXID: bridge.FormatPuppetMXID(dbPuppet.JID), + MXID: br.FormatPuppetMXID(dbPuppet.JID), } } type Puppet struct { *database.Puppet - bridge *Bridge + bridge *WABridge log log.Logger typingIn id.RoomID diff --git a/user.go b/user.go index 01fb813..1d5c111 100644 --- a/user.go +++ b/user.go @@ -35,6 +35,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/id" @@ -56,7 +57,7 @@ type User struct { Client *whatsmeow.Client Session *store.Device - bridge *Bridge + bridge *WABridge log log.Logger Admin bool @@ -84,38 +85,46 @@ type User struct { BackfillQueue *BackfillQueue } -func (bridge *Bridge) getUserByMXID(userID id.UserID, onlyIfExists bool) *User { - _, isPuppet := bridge.ParsePuppetMXID(userID) - if isPuppet || userID == bridge.Bot.UserID { +func (br *WABridge) getUserByMXID(userID id.UserID, onlyIfExists bool) *User { + _, isPuppet := br.ParsePuppetMXID(userID) + if isPuppet || userID == br.Bot.UserID { return nil } - bridge.usersLock.Lock() - defer bridge.usersLock.Unlock() - user, ok := bridge.usersByMXID[userID] + br.usersLock.Lock() + defer br.usersLock.Unlock() + user, ok := br.usersByMXID[userID] if !ok { userIDPtr := &userID if onlyIfExists { userIDPtr = nil } - return bridge.loadDBUser(bridge.DB.User.GetByMXID(userID), userIDPtr) + return br.loadDBUser(br.DB.User.GetByMXID(userID), userIDPtr) } return user } -func (bridge *Bridge) GetUserByMXID(userID id.UserID) *User { - return bridge.getUserByMXID(userID, false) +func (br *WABridge) GetUserByMXID(userID id.UserID) *User { + return br.getUserByMXID(userID, false) } -func (bridge *Bridge) GetUserByMXIDIfExists(userID id.UserID) *User { - return bridge.getUserByMXID(userID, true) +func (br *WABridge) GetIUserByMXID(userID id.UserID) bridge.User { + return br.getUserByMXID(userID, false) } -func (bridge *Bridge) GetUserByJID(jid types.JID) *User { - bridge.usersLock.Lock() - defer bridge.usersLock.Unlock() - user, ok := bridge.usersByUsername[jid.User] +func (user *User) IsAdmin() bool { + return user.Admin +} + +func (br *WABridge) GetUserByMXIDIfExists(userID id.UserID) *User { + return br.getUserByMXID(userID, true) +} + +func (br *WABridge) GetUserByJID(jid types.JID) *User { + br.usersLock.Lock() + defer br.usersLock.Unlock() + user, ok := br.usersByUsername[jid.User] if !ok { - return bridge.loadDBUser(bridge.DB.User.GetByUsername(jid.User), nil) + return br.loadDBUser(br.DB.User.GetByUsername(jid.User), nil) } return user } @@ -137,35 +146,35 @@ func (user *User) removeFromJIDMap(state BridgeState) { user.sendBridgeState(state) } -func (bridge *Bridge) GetAllUsers() []*User { - bridge.usersLock.Lock() - defer bridge.usersLock.Unlock() - dbUsers := bridge.DB.User.GetAll() +func (br *WABridge) GetAllUsers() []*User { + br.usersLock.Lock() + defer br.usersLock.Unlock() + dbUsers := br.DB.User.GetAll() output := make([]*User, len(dbUsers)) for index, dbUser := range dbUsers { - user, ok := bridge.usersByMXID[dbUser.MXID] + user, ok := br.usersByMXID[dbUser.MXID] if !ok { - user = bridge.loadDBUser(dbUser, nil) + user = br.loadDBUser(dbUser, nil) } output[index] = user } return output } -func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User { +func (br *WABridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User { if dbUser == nil { if mxid == nil { return nil } - dbUser = bridge.DB.User.New() + dbUser = br.DB.User.New() dbUser.MXID = *mxid dbUser.Insert() } - user := bridge.NewUser(dbUser) - bridge.usersByMXID[user.MXID] = user + user := br.NewUser(dbUser) + br.usersByMXID[user.MXID] = user if !user.JID.IsEmpty() { var err error - user.Session, err = bridge.WAContainer.GetDevice(user.JID) + user.Session, err = br.WAContainer.GetDevice(user.JID) if err != nil { user.log.Errorfln("Failed to load user's whatsapp session: %v", err) } else if user.Session == nil { @@ -174,20 +183,20 @@ func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User { user.Update() } else { user.Session.Log = &waLogger{user.log.Sub("Session")} - bridge.usersByUsername[user.JID.User] = user + br.usersByUsername[user.JID.User] = user } } if len(user.ManagementRoom) > 0 { - bridge.managementRooms[user.ManagementRoom] = user + br.managementRooms[user.ManagementRoom] = user } return user } -func (bridge *Bridge) NewUser(dbUser *database.User) *User { +func (br *WABridge) NewUser(dbUser *database.User) *User { user := &User{ User: dbUser, - bridge: bridge, - log: bridge.Log.Sub("User").Sub(string(dbUser.MXID)), + bridge: br, + log: br.Log.Sub("User").Sub(string(dbUser.MXID)), historySyncs: make(chan *events.HistorySync, 32), lastPresence: types.PresenceUnavailable,