Move a bunch of stuff to mautrix-go

See d578d1a610

Database upgrades from before v0.4.0 were squashed, users must update
to at least v0.4.0 before updating beyond this commit.
This commit is contained in:
Tulir Asokan 2022-05-22 01:06:30 +03:00
parent 42a4839a4e
commit a948ea0146
83 changed files with 627 additions and 2838 deletions

View File

@ -114,18 +114,18 @@ func (pong *BridgeState) shouldDeduplicate(newPong *BridgeState) bool {
return pong.Timestamp+int64(pong.TTL/5) > time.Now().Unix()
}
func (bridge *Bridge) sendBridgeState(ctx context.Context, state *BridgeState) error {
func (br *WABridge) sendBridgeState(ctx context.Context, state *BridgeState) error {
var body bytes.Buffer
if err := json.NewEncoder(&body).Encode(&state); err != nil {
return fmt.Errorf("failed to encode bridge state JSON: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, bridge.Config.Homeserver.StatusEndpoint, &body)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, br.Config.Homeserver.StatusEndpoint, &body)
if err != nil {
return fmt.Errorf("failed to prepare request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+bridge.Config.AppService.ASToken)
req.Header.Set("Authorization", "Bearer "+br.Config.AppService.ASToken)
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
@ -143,17 +143,17 @@ func (bridge *Bridge) sendBridgeState(ctx context.Context, state *BridgeState) e
return nil
}
func (bridge *Bridge) sendGlobalBridgeState(state BridgeState) {
if len(bridge.Config.Homeserver.StatusEndpoint) == 0 {
func (br *WABridge) sendGlobalBridgeState(state BridgeState) {
if len(br.Config.Homeserver.StatusEndpoint) == 0 {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := bridge.sendBridgeState(ctx, &state); err != nil {
bridge.Log.Warnln("Failed to update global bridge state:", err)
if err := br.sendBridgeState(ctx, &state); err != nil {
br.Log.Warnln("Failed to update global bridge state:", err)
} else {
bridge.Log.Debugfln("Sent new global bridge state %+v", state)
br.Log.Debugfln("Sent new global bridge state %+v", state)
}
}

View File

@ -39,6 +39,7 @@ import (
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
@ -47,12 +48,12 @@ import (
)
type CommandHandler struct {
bridge *Bridge
bridge *WABridge
log maulogger.Logger
}
// NewCommandHandler creates a CommandHandler
func NewCommandHandler(bridge *Bridge) *CommandHandler {
func NewCommandHandler(bridge *WABridge) *CommandHandler {
return &CommandHandler{
bridge: bridge,
log: bridge.Log.Sub("Command handler"),
@ -62,7 +63,7 @@ func NewCommandHandler(bridge *Bridge) *CommandHandler {
// CommandEvent stores all data which might be used to handle commands
type CommandEvent struct {
Bot *appservice.IntentAPI
Bridge *Bridge
Bridge *WABridge
Portal *Portal
Handler *CommandHandler
RoomID id.RoomID
@ -251,13 +252,7 @@ func (handler *CommandHandler) CommandDevTest(_ *CommandEvent) {
const cmdVersionHelp = `version - View the bridge version`
func (handler *CommandHandler) CommandVersion(ce *CommandEvent) {
linkifiedVersion := fmt.Sprintf("v%s", Version)
if Tag == Version {
linkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", Version, URL, Tag)
} else if len(Commit) > 8 {
linkifiedVersion = strings.Replace(linkifiedVersion, Commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", Commit[:8], URL, Commit), 1)
}
ce.Reply(fmt.Sprintf("[%s](%s) %s (%s)", Name, URL, linkifiedVersion, BuildTime))
ce.Reply(fmt.Sprintf("[%s](%s) %s (%s)", ce.Bridge.Name, ce.Bridge.URL, ce.Bridge.LinkifiedVersion, BuildTime))
}
const cmdInviteLinkHelp = `invite-link [--reset] - Get an invite link to the current group chat, optionally regenerating the link and revoking the old link.`
@ -331,7 +326,7 @@ func (handler *CommandHandler) CommandJoin(ce *CommandEvent) {
ce.Reply("Successfully joined group `%s`, the portal should be created momentarily", jid)
}
func tryDecryptEvent(crypto Crypto, evt *event.Event) (json.RawMessage, error) {
func tryDecryptEvent(crypto bridge.Crypto, evt *event.Event) (json.RawMessage, error) {
var data json.RawMessage
if evt.Type != event.EventEncrypted {
data = evt.Content.VeryRaw
@ -903,7 +898,7 @@ func matchesQuery(str string, query string) bool {
return strings.Contains(strings.ToLower(str), query)
}
func formatContacts(bridge *Bridge, input map[types.JID]types.ContactInfo, query string) (result []string) {
func formatContacts(bridge *WABridge, input map[types.JID]types.ContactInfo, query string) (result []string) {
hasQuery := len(query) > 0
for jid, contact := range input {
if len(contact.FullName) == 0 {

View File

@ -24,6 +24,7 @@ import (
"go.mau.fi/whatsmeow/types"
"maunium.net/go/mautrix/bridge/bridgeconfig"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@ -118,23 +119,23 @@ type BridgeConfig struct {
AdditionalHelp string `yaml:"additional_help"`
} `yaml:"management_room_text"`
Encryption struct {
Allow bool `yaml:"allow"`
Default bool `yaml:"default"`
Encryption bridgeconfig.EncryptionConfig `yaml:"encryption"`
KeySharing struct {
Allow bool `yaml:"allow"`
RequireCrossSigning bool `yaml:"require_cross_signing"`
RequireVerification bool `yaml:"require_verification"`
} `yaml:"key_sharing"`
} `yaml:"encryption"`
Provisioning struct {
Prefix string `yaml:"prefix"`
SharedSecret string `yaml:"shared_secret"`
} `yaml:"provisioning"`
Permissions PermissionConfig `yaml:"permissions"`
Relay RelaybotConfig `yaml:"relay"`
usernameTemplate *template.Template `yaml:"-"`
displaynameTemplate *template.Template `yaml:"-"`
ParsedUsernameTemplate *template.Template `yaml:"-"`
displaynameTemplate *template.Template `yaml:"-"`
}
func (bc BridgeConfig) GetEncryptionConfig() bridgeconfig.EncryptionConfig {
return bc.Encryption
}
type umBridgeConfig BridgeConfig
@ -145,7 +146,7 @@ func (bc *BridgeConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
return err
}
bc.usernameTemplate, err = template.New("username").Parse(bc.UsernameTemplate)
bc.ParsedUsernameTemplate, err = template.New("username").Parse(bc.UsernameTemplate)
if err != nil {
return err
} else if !strings.Contains(bc.FormatUsername("1234567890"), "1234567890") {
@ -206,7 +207,7 @@ func (bc BridgeConfig) FormatDisplayname(jid types.JID, contact types.ContactInf
func (bc BridgeConfig) FormatUsername(username string) string {
var buf strings.Builder
_ = bc.usernameTemplate.Execute(&buf, username)
_ = bc.ParsedUsernameTemplate.Execute(&buf, username)
return buf.String()
}

View File

@ -17,52 +17,12 @@
package config
import (
"fmt"
"gopkg.in/yaml.v3"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge/bridgeconfig"
"maunium.net/go/mautrix/id"
)
var ExampleConfig string
type Config struct {
Homeserver struct {
Address string `yaml:"address"`
Domain string `yaml:"domain"`
Asmux bool `yaml:"asmux"`
StatusEndpoint string `yaml:"status_endpoint"`
MessageSendCheckpointEndpoint string `yaml:"message_send_checkpoint_endpoint"`
AsyncMedia bool `yaml:"async_media"`
} `yaml:"homeserver"`
AppService struct {
Address string `yaml:"address"`
Hostname string `yaml:"hostname"`
Port uint16 `yaml:"port"`
Database DatabaseConfig `yaml:"database"`
Provisioning struct {
Prefix string `yaml:"prefix"`
SharedSecret string `yaml:"shared_secret"`
} `yaml:"provisioning"`
ID string `yaml:"id"`
Bot struct {
Username string `yaml:"username"`
Displayname string `yaml:"displayname"`
Avatar string `yaml:"avatar"`
ParsedAvatar id.ContentURI `yaml:"-"`
} `yaml:"bot"`
EphemeralEvents bool `yaml:"ephemeral_events"`
ASToken string `yaml:"as_token"`
HSToken string `yaml:"hs_token"`
} `yaml:"appservice"`
*bridgeconfig.BaseConfig `yaml:",inline"`
SegmentKey string `yaml:"segment_key"`
@ -77,8 +37,6 @@ type Config struct {
} `yaml:"whatsapp"`
Bridge BridgeConfig `yaml:"bridge"`
Logging appservice.LogConfig `yaml:"logging"`
}
func (config *Config) CanAutoDoublePuppet(userID id.UserID) bool {
@ -98,44 +56,3 @@ func (config *Config) CanDoublePuppetBackfill(userID id.UserID) bool {
}
return true
}
func Load(data []byte, upgraded bool) (*Config, error) {
var config = &Config{}
if !upgraded {
// Fallback: if config upgrading failed, load example config for base values
err := yaml.Unmarshal([]byte(ExampleConfig), config)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal example config: %w", err)
}
}
err := yaml.Unmarshal(data, config)
if err != nil {
return nil, err
}
return config, err
}
func (config *Config) MakeAppService() (*appservice.AppService, error) {
as := appservice.Create()
as.HomeserverDomain = config.Homeserver.Domain
as.HomeserverURL = config.Homeserver.Address
as.Host.Hostname = config.AppService.Hostname
as.Host.Port = config.AppService.Port
as.MessageSendCheckpointEndpoint = config.Homeserver.MessageSendCheckpointEndpoint
as.DefaultHTTPRetries = 4
var err error
as.Registration, err = config.GetRegistration()
return as, err
}
type DatabaseConfig struct {
Type string `yaml:"type"`
URI string `yaml:"uri"`
MaxOpenConns int `yaml:"max_open_conns"`
MaxIdleConns int `yaml:"max_idle_conns"`
ConnMaxIdleTime string `yaml:"conn_max_idle_time"`
ConnMaxLifetime string `yaml:"conn_max_lifetime"`
}

View File

@ -1,82 +0,0 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package config
import (
"fmt"
"regexp"
"strings"
"maunium.net/go/mautrix/appservice"
)
func (config *Config) NewRegistration() (*appservice.Registration, error) {
registration := appservice.CreateRegistration()
err := config.copyToRegistration(registration)
if err != nil {
return nil, err
}
config.AppService.ASToken = registration.AppToken
config.AppService.HSToken = registration.ServerToken
// Workaround for https://github.com/matrix-org/synapse/pull/5758
registration.SenderLocalpart = appservice.RandomString(32)
botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$",
regexp.QuoteMeta(config.AppService.Bot.Username),
regexp.QuoteMeta(config.Homeserver.Domain)))
registration.Namespaces.RegisterUserIDs(botRegex, true)
return registration, nil
}
func (config *Config) GetRegistration() (*appservice.Registration, error) {
registration := appservice.CreateRegistration()
err := config.copyToRegistration(registration)
if err != nil {
return nil, err
}
registration.AppToken = config.AppService.ASToken
registration.ServerToken = config.AppService.HSToken
return registration, nil
}
func (config *Config) copyToRegistration(registration *appservice.Registration) error {
registration.ID = config.AppService.ID
registration.URL = config.AppService.Address
falseVal := false
registration.RateLimited = &falseVal
registration.SenderLocalpart = config.AppService.Bot.Username
registration.EphemeralEvents = config.AppService.EphemeralEvents
usernamePlaceholder := appservice.RandomString(16)
usernameTemplate := fmt.Sprintf("@%s:%s",
config.Bridge.FormatUsername(usernamePlaceholder),
config.Homeserver.Domain)
usernameTemplate = regexp.QuoteMeta(usernameTemplate)
usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, "[0-9]+", 1)
usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate)
userIDRegex, err := regexp.Compile(usernameTemplate)
if err != nil {
return err
}
registration.Namespaces.RegisterUserIDs(userIDRegex, true)
return nil
}

View File

@ -20,50 +20,12 @@ import (
"strings"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge/bridgeconfig"
up "maunium.net/go/mautrix/util/configupgrade"
)
type waUpgrader struct{}
func (wau waUpgrader) GetBase() string {
return ExampleConfig
}
func (wau waUpgrader) DoUpgrade(helper *up.Helper) {
helper.Copy(up.Str, "homeserver", "address")
helper.Copy(up.Str, "homeserver", "domain")
helper.Copy(up.Bool, "homeserver", "asmux")
helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint")
helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint")
helper.Copy(up.Bool, "homeserver", "async_media")
helper.Copy(up.Str, "appservice", "address")
helper.Copy(up.Str, "appservice", "hostname")
helper.Copy(up.Int, "appservice", "port")
helper.Copy(up.Str, "appservice", "database", "type")
helper.Copy(up.Str, "appservice", "database", "uri")
helper.Copy(up.Int, "appservice", "database", "max_open_conns")
helper.Copy(up.Int, "appservice", "database", "max_idle_conns")
helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_idle_time")
helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_lifetime")
if prefix, ok := helper.Get(up.Str, "appservice", "provisioning", "prefix"); ok && strings.HasSuffix(prefix, "/v1") {
helper.Set(up.Str, strings.TrimSuffix(prefix, "/v1"), "appservice", "provisioning", "prefix")
} else {
helper.Copy(up.Str, "appservice", "provisioning", "prefix")
}
if secret, ok := helper.Get(up.Str, "appservice", "provisioning", "shared_secret"); !ok || secret == "generate" {
sharedSecret := appservice.RandomString(64)
helper.Set(up.Str, sharedSecret, "appservice", "provisioning", "shared_secret")
} else {
helper.Copy(up.Str, "appservice", "provisioning", "shared_secret")
}
helper.Copy(up.Str, "appservice", "id")
helper.Copy(up.Str, "appservice", "bot", "username")
helper.Copy(up.Str, "appservice", "bot", "displayname")
helper.Copy(up.Str, "appservice", "bot", "avatar")
helper.Copy(up.Bool, "appservice", "ephemeral_events")
helper.Copy(up.Str, "appservice", "as_token")
helper.Copy(up.Str, "appservice", "hs_token")
func DoUpgrade(helper *up.Helper) {
bridgeconfig.Upgrader.DoUpgrade(helper)
helper.Copy(up.Str|up.Null, "segment_key")
@ -134,46 +96,41 @@ func (wau waUpgrader) DoUpgrade(helper *up.Helper) {
helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "allow")
helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "require_cross_signing")
helper.Copy(up.Bool, "bridge", "encryption", "key_sharing", "require_verification")
if prefix, ok := helper.Get(up.Str, "appservice", "provisioning", "prefix"); ok {
helper.Set(up.Str, strings.TrimSuffix(prefix, "/v1"), "bridge", "provisioning", "prefix")
} else {
helper.Copy(up.Str, "bridge", "provisioning", "prefix")
}
if secret, ok := helper.Get(up.Str, "appservice", "provisioning", "shared_secret"); ok && secret != "generate" {
helper.Set(up.Str, secret, "bridge", "provisioning", "shared_secret")
} else if secret, ok = helper.Get(up.Str, "bridge", "provisioning", "shared_secret"); !ok || secret == "generate" {
sharedSecret := appservice.RandomString(64)
helper.Set(up.Str, sharedSecret, "bridge", "provisioning", "shared_secret")
} else {
helper.Copy(up.Str, "bridge", "provisioning", "shared_secret")
}
helper.Copy(up.Map, "bridge", "permissions")
helper.Copy(up.Bool, "bridge", "relay", "enabled")
helper.Copy(up.Bool, "bridge", "relay", "admin_only")
helper.Copy(up.Map, "bridge", "relay", "message_formats")
helper.Copy(up.Str, "logging", "directory")
helper.Copy(up.Str|up.Null, "logging", "file_name_format")
helper.Copy(up.Str|up.Timestamp, "logging", "file_date_format")
helper.Copy(up.Int, "logging", "file_mode")
helper.Copy(up.Str|up.Timestamp, "logging", "timestamp_format")
helper.Copy(up.Str, "logging", "print_level")
}
func (wau waUpgrader) SpacedBlocks() [][]string {
return [][]string{
{"homeserver", "asmux"},
{"appservice"},
{"appservice", "hostname"},
{"appservice", "database"},
{"appservice", "provisioning"},
{"appservice", "id"},
{"appservice", "as_token"},
{"segment_key"},
{"metrics"},
{"whatsapp"},
{"bridge"},
{"bridge", "command_prefix"},
{"bridge", "management_room_text"},
{"bridge", "encryption"},
{"bridge", "permissions"},
{"bridge", "relay"},
{"logging"},
}
}
func Mutate(path string, mutate func(helper *up.Helper)) error {
_, _, err := up.Do(path, true, waUpgrader{}, up.SimpleUpgrader(mutate))
return err
}
func Upgrade(path string, save bool) ([]byte, bool, error) {
return up.Do(path, save, waUpgrader{})
var SpacedBlocks = [][]string{
{"homeserver", "asmux"},
{"appservice"},
{"appservice", "hostname"},
{"appservice", "database"},
{"appservice", "id"},
{"appservice", "as_token"},
{"segment_key"},
{"metrics"},
{"whatsapp"},
{"bridge"},
{"bridge", "command_prefix"},
{"bridge", "management_room_text"},
{"bridge", "encryption"},
{"bridge", "provisioning"},
{"bridge", "permissions"},
{"bridge", "relay"},
{"logging"},
}

327
crypto.go
View File

@ -1,327 +0,0 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2020 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//go:build cgo && !nocrypto
package main
import (
"fmt"
"runtime/debug"
"time"
"github.com/lib/pq"
"maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix-whatsapp/database"
)
var NoSessionFound = crypto.NoSessionFound
var levelTrace = maulogger.Level{
Name: "TRACE",
Severity: -10,
Color: -1,
}
type CryptoHelper struct {
bridge *Bridge
client *mautrix.Client
mach *crypto.OlmMachine
store *database.SQLCryptoStore
log maulogger.Logger
baseLog maulogger.Logger
}
func init() {
crypto.PostgresArrayWrapper = pq.Array
}
func NewCryptoHelper(bridge *Bridge) Crypto {
if !bridge.Config.Bridge.Encryption.Allow {
bridge.Log.Debugln("Bridge built with end-to-bridge encryption, but disabled in config")
return nil
}
baseLog := bridge.Log.Sub("Crypto")
return &CryptoHelper{
bridge: bridge,
log: baseLog.Sub("Helper"),
baseLog: baseLog,
}
}
func (helper *CryptoHelper) Init() error {
helper.log.Debugln("Initializing end-to-bridge encryption...")
helper.store = database.NewSQLCryptoStore(helper.bridge.DB, helper.bridge.AS.BotMXID(),
fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain))
var err error
helper.client, err = helper.loginBot()
if err != nil {
return err
}
helper.log.Debugln("Logged in as bridge bot with device ID", helper.client.DeviceID)
logger := &cryptoLogger{helper.baseLog}
stateStore := &cryptoStateStore{helper.bridge}
helper.mach = crypto.NewOlmMachine(helper.client, logger, helper.store, stateStore)
helper.mach.AllowKeyShare = helper.allowKeyShare
helper.client.Syncer = &cryptoSyncer{helper.mach}
helper.client.Store = &cryptoClientStore{helper.store}
return helper.mach.Load()
}
func (helper *CryptoHelper) allowKeyShare(device *crypto.DeviceIdentity, info event.RequestedKeyInfo) *crypto.KeyShareRejection {
cfg := helper.bridge.Config.Bridge.Encryption.KeySharing
if !cfg.Allow {
return &crypto.KeyShareRejectNoResponse
} else if device.Trust == crypto.TrustStateBlacklisted {
return &crypto.KeyShareRejectBlacklisted
} else if device.Trust == crypto.TrustStateVerified || !cfg.RequireVerification {
portal := helper.bridge.GetPortalByMXID(info.RoomID)
if portal == nil {
helper.log.Debugfln("Rejecting key request for %s from %s/%s: room is not a portal", info.SessionID, device.UserID, device.DeviceID)
return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnavailable, Reason: "Requested room is not a portal room"}
}
user := helper.bridge.GetUserByMXID(device.UserID)
// FIXME reimplement IsInPortal
if !user.Admin /*&& !user.IsInPortal(portal.Key)*/ {
helper.log.Debugfln("Rejecting key request for %s from %s/%s: user is not in portal", info.SessionID, device.UserID, device.DeviceID)
return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "You're not in that portal"}
}
helper.log.Debugfln("Accepting key request for %s from %s/%s", info.SessionID, device.UserID, device.DeviceID)
return nil
} else {
return &crypto.KeyShareRejectUnverified
}
}
func (helper *CryptoHelper) loginBot() (*mautrix.Client, error) {
deviceID := helper.store.FindDeviceID()
if len(deviceID) > 0 {
helper.log.Debugln("Found existing device ID for bot in database:", deviceID)
}
client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, "", "")
if err != nil {
return nil, fmt.Errorf("failed to initialize client: %w", err)
}
client.Logger = helper.baseLog.Sub("Bot")
client.Client = helper.bridge.AS.HTTPClient
client.DefaultHTTPRetries = helper.bridge.AS.DefaultHTTPRetries
flows, err := client.GetLoginFlows()
if err != nil {
return nil, fmt.Errorf("failed to get supported login flows: %w", err)
} else if !flows.HasFlow(mautrix.AuthTypeAppservice) {
return nil, fmt.Errorf("homeserver does not support appservice login")
}
// We set the API token to the AS token here to authenticate the appservice login
// It'll get overridden after the login
client.AccessToken = helper.bridge.AS.Registration.AppToken
resp, err := client.Login(&mautrix.ReqLogin{
Type: mautrix.AuthTypeAppservice,
Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(helper.bridge.AS.BotMXID())},
DeviceID: deviceID,
InitialDeviceDisplayName: "WhatsApp Bridge",
StoreCredentials: true,
})
if err != nil {
return nil, fmt.Errorf("failed to log in as bridge bot: %w", err)
}
helper.store.DeviceID = resp.DeviceID
return client, nil
}
func (helper *CryptoHelper) Start() {
helper.log.Debugln("Starting syncer for receiving to-device messages")
err := helper.client.Sync()
if err != nil {
helper.log.Errorln("Fatal error syncing:", err)
} else {
helper.log.Infoln("Bridge bot to-device syncer stopped without error")
}
}
func (helper *CryptoHelper) Stop() {
helper.log.Debugln("CryptoHelper.Stop() called, stopping bridge bot sync")
helper.client.StopSync()
}
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
return helper.mach.DecryptMegolmEvent(evt)
}
func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content event.Content) (*event.EncryptedEventContent, error) {
encrypted, err := helper.mach.EncryptMegolmEvent(roomID, evtType, &content)
if err != nil {
if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
return nil, err
}
helper.log.Debugfln("Got %v while encrypting event for %s, sharing group session and trying again...", err, roomID)
users, err := helper.store.GetRoomMembers(roomID)
if err != nil {
return nil, fmt.Errorf("failed to get room member list: %w", err)
}
err = helper.mach.ShareGroupSession(roomID, users)
if err != nil {
return nil, fmt.Errorf("failed to share group session: %w", err)
}
encrypted, err = helper.mach.EncryptMegolmEvent(roomID, evtType, &content)
if err != nil {
return nil, fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err)
}
}
return encrypted, nil
}
func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
}
func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}})
if err != nil {
helper.log.Warnfln("Failed to send key request to %s/%s for %s in %s: %v", userID, deviceID, sessionID, roomID, err)
} else {
helper.log.Debugfln("Sent key request to %s/%s for %s in %s", userID, deviceID, sessionID, roomID)
}
}
func (helper *CryptoHelper) ResetSession(roomID id.RoomID) {
err := helper.mach.CryptoStore.RemoveOutboundGroupSession(roomID)
if err != nil {
helper.log.Debugfln("Error manually removing outbound group session in %s: %v", roomID, err)
}
}
func (helper *CryptoHelper) HandleMemberEvent(evt *event.Event) {
helper.mach.HandleMemberEvent(evt)
}
type cryptoSyncer struct {
*crypto.OlmMachine
}
func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string) error {
done := make(chan struct{})
go func() {
defer func() {
if err := recover(); err != nil {
syncer.Log.Error("Processing sync response (%s) panicked: %v\n%s", since, err, debug.Stack())
}
done <- struct{}{}
}()
syncer.Log.Trace("Starting sync response handling (%s)", since)
syncer.ProcessSyncResponse(resp, since)
syncer.Log.Trace("Successfully handled sync response (%s)", since)
}()
select {
case <-done:
case <-time.After(30 * time.Second):
syncer.Log.Warn("Handling sync response (%s) is taking unusually long", since)
}
return nil
}
func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) {
syncer.Log.Error("Error /syncing, waiting 10 seconds: %v", err)
return 10 * time.Second, nil
}
func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter {
everything := []event.Type{{Type: "*"}}
return &mautrix.Filter{
Presence: mautrix.FilterPart{NotTypes: everything},
AccountData: mautrix.FilterPart{NotTypes: everything},
Room: mautrix.RoomFilter{
IncludeLeave: false,
Ephemeral: mautrix.FilterPart{NotTypes: everything},
AccountData: mautrix.FilterPart{NotTypes: everything},
State: mautrix.FilterPart{NotTypes: everything},
Timeline: mautrix.FilterPart{NotTypes: everything},
},
}
}
type cryptoLogger struct {
int maulogger.Logger
}
func (c *cryptoLogger) Error(message string, args ...interface{}) {
c.int.Errorfln(message, args...)
}
func (c *cryptoLogger) Warn(message string, args ...interface{}) {
c.int.Warnfln(message, args...)
}
func (c *cryptoLogger) Debug(message string, args ...interface{}) {
c.int.Debugfln(message, args...)
}
func (c *cryptoLogger) Trace(message string, args ...interface{}) {
c.int.Logfln(levelTrace, message, args...)
}
type cryptoClientStore struct {
int *database.SQLCryptoStore
}
func (c cryptoClientStore) SaveFilterID(_ id.UserID, _ string) {}
func (c cryptoClientStore) LoadFilterID(_ id.UserID) string { return "" }
func (c cryptoClientStore) SaveRoom(_ *mautrix.Room) {}
func (c cryptoClientStore) LoadRoom(_ id.RoomID) *mautrix.Room { return nil }
func (c cryptoClientStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
c.int.PutNextBatch(nextBatchToken)
}
func (c cryptoClientStore) LoadNextBatch(_ id.UserID) string {
return c.int.GetNextBatch()
}
var _ mautrix.Storer = (*cryptoClientStore)(nil)
type cryptoStateStore struct {
bridge *Bridge
}
var _ crypto.StateStore = (*cryptoStateStore)(nil)
func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool {
portal := c.bridge.GetPortalByMXID(id)
if portal != nil {
return portal.Encrypted
}
return false
}
func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID {
return c.bridge.StateStore.FindSharedRooms(id)
}
func (c *cryptoStateStore) GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent {
// TODO implement
return nil
}

View File

@ -75,8 +75,8 @@ func (puppet *Puppet) loginWithSharedSecret(mxid id.UserID) (string, error) {
Type: mautrix.AuthTypePassword,
Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(mxid)},
Password: hex.EncodeToString(mac.Sum(nil)),
DeviceID: "WhatsApp Bridge",
InitialDeviceDisplayName: "WhatsApp Bridge",
DeviceID: "WhatsApp bridge",
InitialDeviceDisplayName: "WhatsApp bridge",
})
if err != nil {
return "", err
@ -84,22 +84,22 @@ func (puppet *Puppet) loginWithSharedSecret(mxid id.UserID) (string, error) {
return resp.AccessToken, nil
}
func (bridge *Bridge) newDoublePuppetClient(mxid id.UserID, accessToken string) (*mautrix.Client, error) {
func (br *WABridge) newDoublePuppetClient(mxid id.UserID, accessToken string) (*mautrix.Client, error) {
_, homeserver, err := mxid.Parse()
if err != nil {
return nil, err
}
homeserverURL, found := bridge.Config.Bridge.DoublePuppetServerMap[homeserver]
homeserverURL, found := br.Config.Bridge.DoublePuppetServerMap[homeserver]
if !found {
if homeserver == bridge.AS.HomeserverDomain {
homeserverURL = bridge.AS.HomeserverURL
} else if bridge.Config.Bridge.DoublePuppetAllowDiscovery {
if homeserver == br.AS.HomeserverDomain {
homeserverURL = br.AS.HomeserverURL
} else if br.Config.Bridge.DoublePuppetAllowDiscovery {
resp, err := mautrix.DiscoverClientAPI(homeserver)
if err != nil {
return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err)
}
homeserverURL = resp.Homeserver.BaseURL
bridge.Log.Debugfln("Discovered URL %s for %s to enable double puppeting for %s", homeserverURL, homeserver, mxid)
br.Log.Debugfln("Discovered URL %s for %s to enable double puppeting for %s", homeserverURL, homeserver, mxid)
} else {
return nil, fmt.Errorf("double puppeting from %s is not allowed", homeserver)
}
@ -108,9 +108,9 @@ func (bridge *Bridge) newDoublePuppetClient(mxid id.UserID, accessToken string)
if err != nil {
return nil, err
}
client.Logger = bridge.AS.Log.Sub(mxid.String())
client.Client = bridge.AS.HTTPClient
client.DefaultHTTPRetries = bridge.AS.DefaultHTTPRetries
client.Logger = br.AS.Log.Sub(mxid.String())
client.Client = br.AS.HTTPClient
client.DefaultHTTPRetries = br.AS.DefaultHTTPRetries
return client, nil
}

View File

@ -26,7 +26,9 @@ import (
"time"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
type BackfillType int
@ -165,7 +167,7 @@ func (b *Backfill) String() string {
)
}
func (b *Backfill) Scan(row Scannable) *Backfill {
func (b *Backfill) Scan(row dbutil.Scannable) *Backfill {
err := row.Scan(&b.QueueID, &b.UserID, &b.BackfillType, &b.Priority, &b.Portal.JID, &b.Portal.Receiver, &b.TimeStart, &b.MaxBatchEvents, &b.MaxTotalEvents, &b.BatchDelay)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
@ -256,7 +258,7 @@ type BackfillState struct {
FirstExpectedTimestamp uint64
}
func (b *BackfillState) Scan(row Scannable) *BackfillState {
func (b *BackfillState) Scan(row dbutil.Scannable) *BackfillState {
err := row.Scan(&b.UserID, &b.Portal.JID, &b.Portal.Receiver, &b.ProcessingBatch, &b.BackfillComplete, &b.FirstExpectedTimestamp)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {

View File

@ -1,18 +1,8 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2020 Tulir Asokan
// Copyright (c) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build cgo && !nocrypto
@ -21,8 +11,6 @@ package database
import (
"database/sql"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/id"
)
@ -37,11 +25,9 @@ var _ crypto.Store = (*SQLCryptoStore)(nil)
func NewSQLCryptoStore(db *Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore {
return &SQLCryptoStore{
SQLCryptoStore: crypto.NewSQLCryptoStore(db.DB, db.dialect, "", "",
[]byte("maunium.net/go/mautrix-whatsapp"),
&cryptoLogger{db.log.Sub("CryptoStore")}),
UserID: userID,
GhostIDFormat: ghostIDFormat,
SQLCryptoStore: crypto.NewSQLCryptoStore(db.Database, "", "", []byte("maunium.net/go/mautrix-whatsapp")),
UserID: userID,
GhostIDFormat: ghostIDFormat,
}
}
@ -76,30 +62,3 @@ func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.User
}
return
}
// TODO merge this with the one in the parent package
type cryptoLogger struct {
int log.Logger
}
var levelTrace = log.Level{
Name: "TRACE",
Severity: -10,
Color: -1,
}
func (c *cryptoLogger) Error(message string, args ...interface{}) {
c.int.Errorfln(message, args...)
}
func (c *cryptoLogger) Warn(message string, args ...interface{}) {
c.int.Warnfln(message, args...)
}
func (c *cryptoLogger) Debug(message string, args ...interface{}) {
c.int.Debugfln(message, args...)
}
func (c *cryptoLogger) Trace(message string, args ...interface{}) {
c.int.Logfln(levelTrace, message, args...)
}

View File

@ -17,21 +17,17 @@
package database
import (
"database/sql"
"errors"
"fmt"
"net"
"time"
"github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
log "maunium.net/go/maulogger/v2"
"go.mau.fi/whatsmeow/store"
"go.mau.fi/whatsmeow/store/sqlstore"
"maunium.net/go/mautrix-whatsapp/config"
"maunium.net/go/mautrix-whatsapp/database/upgrades"
"maunium.net/go/mautrix/util/dbutil"
)
func init() {
@ -39,9 +35,7 @@ func init() {
}
type Database struct {
*sql.DB
log log.Logger
dialect string
*dbutil.Database
User *UserQuery
Portal *PortalQuery
@ -55,79 +49,46 @@ type Database struct {
MediaBackfillRequest *MediaBackfillRequestQuery
}
func New(cfg config.DatabaseConfig, baseLog log.Logger) (*Database, error) {
conn, err := sql.Open(cfg.Type, cfg.URI)
if err != nil {
return nil, err
}
db := &Database{
DB: conn,
log: baseLog.Sub("Database"),
dialect: cfg.Type,
}
func New(baseDB *dbutil.Database) *Database {
db := &Database{Database: baseDB}
db.UpgradeTable = upgrades.Table
db.User = &UserQuery{
db: db,
log: db.log.Sub("User"),
log: db.Log.Sub("User"),
}
db.Portal = &PortalQuery{
db: db,
log: db.log.Sub("Portal"),
log: db.Log.Sub("Portal"),
}
db.Puppet = &PuppetQuery{
db: db,
log: db.log.Sub("Puppet"),
log: db.Log.Sub("Puppet"),
}
db.Message = &MessageQuery{
db: db,
log: db.log.Sub("Message"),
log: db.Log.Sub("Message"),
}
db.Reaction = &ReactionQuery{
db: db,
log: db.log.Sub("Reaction"),
log: db.Log.Sub("Reaction"),
}
db.DisappearingMessage = &DisappearingMessageQuery{
db: db,
log: db.log.Sub("DisappearingMessage"),
log: db.Log.Sub("DisappearingMessage"),
}
db.Backfill = &BackfillQuery{
db: db,
log: db.log.Sub("Backfill"),
log: db.Log.Sub("Backfill"),
}
db.HistorySync = &HistorySyncQuery{
db: db,
log: db.log.Sub("HistorySync"),
log: db.Log.Sub("HistorySync"),
}
db.MediaBackfillRequest = &MediaBackfillRequestQuery{
db: db,
log: db.log.Sub("MediaBackfillRequest"),
log: db.Log.Sub("MediaBackfillRequest"),
}
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetMaxIdleConns(cfg.MaxIdleConns)
if len(cfg.ConnMaxIdleTime) > 0 {
maxIdleTimeDuration, err := time.ParseDuration(cfg.ConnMaxIdleTime)
if err != nil {
return nil, fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
}
db.SetConnMaxIdleTime(maxIdleTimeDuration)
}
if len(cfg.ConnMaxLifetime) > 0 {
maxLifetimeDuration, err := time.ParseDuration(cfg.ConnMaxLifetime)
if err != nil {
return nil, fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
}
db.SetConnMaxLifetime(maxLifetimeDuration)
}
return db, nil
}
func (db *Database) Init() error {
return upgrades.Run(db.log.Sub("Upgrade"), db.dialect, db.DB)
}
type Scannable interface {
Scan(...interface{}) error
return db
}
func isRetryableError(err error) bool {
@ -145,7 +106,7 @@ func isRetryableError(err error) bool {
}
func (db *Database) HandleSignalStoreError(device *store.Device, action string, attemptIndex int, err error) (retry bool) {
if db.dialect != "sqlite" && isRetryableError(err) {
if db.Dialect != dbutil.SQLite && isRetryableError(err) {
sleepTime := time.Duration(attemptIndex*2) * time.Second
device.Log.Warnf("Failed to %s (attempt #%d): %v - retrying in %v", action, attemptIndex+1, err, sleepTime)
time.Sleep(sleepTime)

View File

@ -24,6 +24,7 @@ import (
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
type DisappearingMessageQuery struct {
@ -94,7 +95,7 @@ type DisappearingMessage struct {
ExpireAt time.Time
}
func (msg *DisappearingMessage) Scan(row Scannable) *DisappearingMessage {
func (msg *DisappearingMessage) Scan(row dbutil.Scannable) *DisappearingMessage {
var expireIn int64
var expireAt sql.NullInt64
err := row.Scan(&msg.RoomID, &msg.EventID, &expireIn, &expireAt)

View File

@ -27,7 +27,9 @@ import (
_ "github.com/mattn/go-sqlite3"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
type HistorySyncQuery struct {
@ -139,7 +141,7 @@ func (hsc *HistorySyncConversation) Upsert() {
}
}
func (hsc *HistorySyncConversation) Scan(row Scannable) *HistorySyncConversation {
func (hsc *HistorySyncConversation) Scan(row dbutil.Scannable) *HistorySyncConversation {
err := row.Scan(
&hsc.UserID,
&hsc.ConversationID,
@ -166,7 +168,7 @@ func (hsc *HistorySyncConversation) Scan(row Scannable) *HistorySyncConversation
func (hsq *HistorySyncQuery) GetNMostRecentConversations(userID id.UserID, n int) (conversations []*HistorySyncConversation) {
nPtr := &n
// Negative limit on SQLite means unlimited, but Postgres prefers a NULL limit.
if n < 0 && hsq.db.dialect == "postgres" {
if n < 0 && hsq.db.Dialect == dbutil.Postgres {
nPtr = nil
}
rows, err := hsq.db.Query(getNMostRecentConversations, userID, nPtr)

View File

@ -22,7 +22,9 @@ import (
_ "github.com/mattn/go-sqlite3"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
type MediaBackfillRequestStatus int
@ -100,7 +102,7 @@ func (mbr *MediaBackfillRequest) Upsert() {
}
}
func (mbr *MediaBackfillRequest) Scan(row Scannable) *MediaBackfillRequest {
func (mbr *MediaBackfillRequest) Scan(row dbutil.Scannable) *MediaBackfillRequest {
err := row.Scan(&mbr.UserID, &mbr.PortalKey.JID, &mbr.PortalKey.Receiver, &mbr.EventID, &mbr.MediaKey, &mbr.Status, &mbr.Error)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {

View File

@ -25,6 +25,7 @@ import (
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
"go.mau.fi/whatsmeow/types"
)
@ -163,7 +164,7 @@ func (msg *Message) IsFakeJID() bool {
return strings.HasPrefix(msg.JID, "FAKE::") || msg.JID == string(msg.MXID)
}
func (msg *Message) Scan(row Scannable) *Message {
func (msg *Message) Scan(row dbutil.Scannable) *Message {
var ts int64
err := row.Scan(&msg.Chat.JID, &msg.Chat.Receiver, &msg.JID, &msg.MXID, &msg.Sender, &ts, &msg.Sent, &msg.Type, &msg.Error, &msg.BroadcastListJID)
if err != nil {

View File

@ -22,6 +22,7 @@ import (
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
"go.mau.fi/whatsmeow/types"
)
@ -152,7 +153,7 @@ type Portal struct {
ExpirationTime uint32
}
func (portal *Portal) Scan(row Scannable) *Portal {
func (portal *Portal) Scan(row dbutil.Scannable) *Portal {
var mxid, avatarURL, firstEventID, nextBatchID, relayUserID sql.NullString
err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar, &avatarURL, &portal.Encrypted, &firstEventID, &nextBatchID, &relayUserID, &portal.ExpirationTime)
if err != nil {

View File

@ -20,7 +20,9 @@ import (
"database/sql"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
"go.mau.fi/whatsmeow/types"
)
@ -97,7 +99,7 @@ type Puppet struct {
EnableReceipts bool
}
func (puppet *Puppet) Scan(row Scannable) *Puppet {
func (puppet *Puppet) Scan(row dbutil.Scannable) *Puppet {
var displayname, avatar, avatarURL, customMXID, accessToken, nextBatch sql.NullString
var quality sql.NullInt64
var enablePresence, enableReceipts sql.NullBool

View File

@ -23,6 +23,7 @@ import (
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
"go.mau.fi/whatsmeow/types"
)
@ -85,7 +86,7 @@ type Reaction struct {
JID types.MessageID
}
func (reaction *Reaction) Scan(row Scannable) *Reaction {
func (reaction *Reaction) Scan(row dbutil.Scannable) *Reaction {
err := row.Scan(&reaction.Chat.JID, &reaction.Chat.Receiver, &reaction.TargetJID, &reaction.Sender, &reaction.MXID, &reaction.JID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {

View File

@ -1,282 +0,0 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package database
import (
"database/sql"
"encoding/json"
"errors"
"sync"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
type SQLStateStore struct {
*appservice.TypingStateStore
db *Database
log log.Logger
Typing map[id.RoomID]map[id.UserID]int64
typingLock sync.RWMutex
}
var _ appservice.StateStore = (*SQLStateStore)(nil)
func NewSQLStateStore(db *Database) *SQLStateStore {
return &SQLStateStore{
TypingStateStore: appservice.NewTypingStateStore(),
db: db,
log: db.log.Sub("StateStore"),
}
}
func (store *SQLStateStore) IsRegistered(userID id.UserID) bool {
var isRegistered bool
err := store.db.
QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
Scan(&isRegistered)
if err != nil {
store.log.Warnfln("Failed to scan registration existence for %s: %v", userID, err)
}
return isRegistered
}
func (store *SQLStateStore) MarkRegistered(userID id.UserID) {
_, err := store.db.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
if err != nil {
store.log.Warnfln("Failed to mark %s as registered: %v", userID, err)
}
}
func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
members := make(map[id.UserID]*event.MemberEventContent)
rows, err := store.db.Query("SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID)
if err != nil {
return members
}
var userID id.UserID
var member event.MemberEventContent
for rows.Next() {
err = rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil {
store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
} else {
members[userID] = &member
}
}
return members
}
func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
membership := event.MembershipLeave
err := store.db.
QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
Scan(&membership)
if err != nil && err != sql.ErrNoRows {
store.log.Warnfln("Failed to scan membership of %s in %s: %v", userID, roomID, err)
}
return membership
}
func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
member, ok := store.TryGetMember(roomID, userID)
if !ok {
member.Membership = event.MembershipLeave
}
return member
}
func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) {
var member event.MemberEventContent
err := store.db.
QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil && err != sql.ErrNoRows {
store.log.Warnfln("Failed to scan member info of %s in %s: %v", userID, roomID, err)
}
return &member, err == nil
}
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
rows, err := store.db.Query(`
SELECT room_id FROM mx_user_profile
LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
WHERE user_id=$1 AND portal.encrypted=true
`, userID)
if err != nil {
store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err)
return
}
for rows.Next() {
var roomID id.RoomID
err = rows.Scan(&roomID)
if err != nil {
store.log.Warnfln("Failed to scan room ID: %v", err)
} else {
rooms = append(rooms, roomID)
}
}
return
}
func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join")
}
func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join", "invite")
}
func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
membership := store.GetMembership(roomID, userID)
for _, allowedMembership := range allowedMemberships {
if allowedMembership == membership {
return true
}
}
return false
}
func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
_, err := store.db.Exec(`
INSERT INTO mx_user_profile (room_id, user_id, membership) VALUES ($1, $2, $3)
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
`, roomID, userID, membership)
if err != nil {
store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, membership, err)
}
}
func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
_, err := store.db.Exec(`
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url
`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
if err != nil {
store.log.Warnfln("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err)
}
}
func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
levelsBytes, err := json.Marshal(levels)
if err != nil {
store.log.Errorfln("Failed to marshal power levels of %s: %v", roomID, err)
return
}
_, err = store.db.Exec(`
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
`, roomID, levelsBytes)
if err != nil {
store.log.Warnfln("Failed to store power levels of %s: %v", roomID, err)
}
}
func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
var data []byte
err := store.db.
QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
Scan(&data)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.log.Errorfln("Failed to scan power levels of %s: %v", roomID, err)
}
return
}
levels = &event.PowerLevelsEventContent{}
err = json.Unmarshal(data, levels)
if err != nil {
store.log.Errorfln("Failed to parse power levels of %s: %v", roomID, err)
return nil
}
return
}
func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
if store.db.dialect == "postgres" {
var powerLevel int
err := store.db.
QueryRow(`
SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
FROM mx_room_state WHERE room_id=$1
`, roomID, userID).
Scan(&powerLevel)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
store.log.Errorfln("Failed to scan power level of %s in %s: %v", userID, roomID, err)
}
return powerLevel
}
return store.GetPowerLevels(roomID).GetUserLevel(userID)
}
func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
if store.db.dialect == "postgres" {
defaultType := "events_default"
defaultValue := 0
if eventType.IsState() {
defaultType = "state_default"
defaultValue = 50
}
var powerLevel int
err := store.db.
QueryRow(`
SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
FROM mx_room_state WHERE room_id=$1
`, roomID, eventType.Type, defaultType, defaultValue).
Scan(&powerLevel)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.log.Errorfln("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
}
return defaultValue
}
return powerLevel
}
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
}
func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
if store.db.dialect == "postgres" {
defaultType := "events_default"
defaultValue := 0
if eventType.IsState() {
defaultType = "state_default"
defaultValue = 50
}
var hasPower bool
err := store.db.
QueryRow(`SELECT
COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
>=
COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
Scan(&hasPower)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.log.Errorfln("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
}
return defaultValue == 0
}
return hasPower
}
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
}

View File

@ -0,0 +1,181 @@
-- v0 -> v48: Latest revision
CREATE TABLE "user" (
mxid TEXT PRIMARY KEY,
username TEXT UNIQUE,
agent SMALLINT,
device SMALLINT,
management_room TEXT,
space_room TEXT,
phone_last_seen BIGINT,
phone_last_pinged BIGINT,
timezone TEXT
);
CREATE TABLE portal (
jid TEXT,
receiver TEXT,
mxid TEXT UNIQUE,
name TEXT NOT NULL,
topic TEXT NOT NULL,
avatar TEXT NOT NULL,
avatar_url TEXT,
encrypted BOOLEAN NOT NULL DEFAULT false,
first_event_id TEXT,
next_batch_id TEXT,
relay_user_id TEXT,
expiration_time BIGINT NOT NULL DEFAULT 0,
PRIMARY KEY (jid, receiver)
);
CREATE TABLE puppet (
username TEXT PRIMARY KEY,
displayname TEXT,
name_quality SMALLINT,
avatar TEXT,
avatar_url TEXT,
custom_mxid TEXT,
access_token TEXT,
next_batch TEXT,
enable_presence BOOLEAN NOT NULL DEFAULT true,
enable_receipts BOOLEAN NOT NULL DEFAULT true
);
-- only: postgres
CREATE TYPE error_type AS ENUM ('', 'decryption_failed', 'media_not_found');
CREATE TABLE message (
chat_jid TEXT,
chat_receiver TEXT,
jid TEXT,
mxid TEXT UNIQUE,
sender TEXT,
timestamp BIGINT,
sent BOOLEAN,
error error_type,
type TEXT,
broadcast_list_jid TEXT,
PRIMARY KEY (chat_jid, chat_receiver, jid),
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
);
CREATE TABLE reaction (
chat_jid TEXT,
chat_receiver TEXT,
target_jid TEXT,
sender TEXT,
mxid TEXT NOT NULL,
jid TEXT NOT NULL,
PRIMARY KEY (chat_jid, chat_receiver, target_jid, sender),
FOREIGN KEY (chat_jid, chat_receiver, target_jid) REFERENCES message(chat_jid, chat_receiver, jid)
ON DELETE CASCADE ON UPDATE CASCADE
);
CREATE TABLE disappearing_message (
room_id TEXT,
event_id TEXT,
expire_in BIGINT NOT NULL,
expire_at BIGINT,
PRIMARY KEY (room_id, event_id)
);
CREATE TABLE user_portal (
user_mxid TEXT,
portal_jid TEXT,
portal_receiver TEXT,
last_read_ts BIGINT NOT NULL DEFAULT 0,
in_space BOOLEAN NOT NULL DEFAULT false,
PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON UPDATE CASCADE ON DELETE CASCADE
);
CREATE TABLE backfill_queue (
queue_id INTEGER PRIMARY KEY
-- only: postgres
GENERATED ALWAYS AS IDENTITY
,
user_mxid TEXT,
type INTEGER NOT NULL,
priority INTEGER NOT NULL,
portal_jid TEXT,
portal_receiver TEXT,
time_start TIMESTAMP,
dispatch_time TIMESTAMP,
completed_at TIMESTAMP,
batch_delay INTEGER,
max_batch_events INTEGER NOT NULL,
max_total_events INTEGER,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
);
CREATE TABLE backfill_state (
user_mxid TEXT,
portal_jid TEXT,
portal_receiver TEXT,
processing_batch BOOLEAN,
backfill_complete BOOLEAN,
first_expected_ts TIMESTAMP,
PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal (jid, receiver) ON DELETE CASCADE
);
CREATE TABLE media_backfill_requests (
user_mxid TEXT,
portal_jid TEXT,
portal_receiver TEXT,
event_id TEXT,
media_key bytea,
status INTEGER,
error TEXT,
PRIMARY KEY (user_mxid, portal_jid, portal_receiver, event_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON UPDATE CASCADE ON DELETE CASCADE
);
CREATE TABLE history_sync_conversation (
user_mxid TEXT,
conversation_id TEXT,
portal_jid TEXT,
portal_receiver TEXT,
last_message_timestamp TIMESTAMP,
archived BOOLEAN,
pinned INTEGER,
mute_end_time TIMESTAMP,
disappearing_mode INTEGER,
end_of_history_transfer_type INTEGER,
ephemeral_Expiration INTEGER,
marked_as_unread BOOLEAN,
unread_count INTEGER,
PRIMARY KEY (user_mxid, conversation_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON UPDATE CASCADE ON DELETE CASCADE
);
CREATE TABLE history_sync_message (
user_mxid TEXT,
conversation_id TEXT,
message_id TEXT,
timestamp TIMESTAMP,
data bytea,
inserted_time TIMESTAMP,
PRIMARY KEY (user_mxid, conversation_id, message_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON UPDATE CASCADE ON DELETE CASCADE,
FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE
);

View File

@ -1,67 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[0] = upgrade{"Initial schema", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`CREATE TABLE IF NOT EXISTS portal (
jid VARCHAR(255),
receiver VARCHAR(255),
mxid VARCHAR(255) UNIQUE,
name VARCHAR(255) NOT NULL,
topic VARCHAR(255) NOT NULL,
avatar VARCHAR(255) NOT NULL,
PRIMARY KEY (jid, receiver)
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS puppet (
jid VARCHAR(255) PRIMARY KEY,
avatar VARCHAR(255),
displayname VARCHAR(255),
name_quality SMALLINT
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS "user" (
mxid VARCHAR(255) PRIMARY KEY,
jid VARCHAR(255) UNIQUE,
management_room VARCHAR(255),
client_id VARCHAR(255),
client_token VARCHAR(255),
server_token VARCHAR(255),
enc_key bytea,
mac_key bytea
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS message (
chat_jid VARCHAR(255),
chat_receiver VARCHAR(255),
jid VARCHAR(255),
mxid VARCHAR(255) NOT NULL UNIQUE,
sender VARCHAR(255) NOT NULL,
content bytea NOT NULL,
PRIMARY KEY (chat_jid, chat_receiver, jid),
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
)`)
if err != nil {
return err
}
return nil
}}
}

View File

@ -1,15 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[2] = upgrade{"Add timestamp column to messages", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec("ALTER TABLE message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0")
if err != nil {
return err
}
return nil
}}
}

View File

@ -1,15 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[3] = upgrade{"Add last_connection column to users", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN last_connection BIGINT NOT NULL DEFAULT 0`)
if err != nil {
return err
}
return nil
}}
}

View File

@ -1,23 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[5] = upgrade{"Add columns to store custom puppet info", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN custom_mxid VARCHAR(255)`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE puppet ADD COLUMN access_token VARCHAR(1023)`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE puppet ADD COLUMN next_batch VARCHAR(255)`)
if err != nil {
return err
}
return nil
}}
}

View File

@ -1,19 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[6] = upgrade{"Add user-portal mapping table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`CREATE TABLE user_portal (
user_jid VARCHAR(255),
portal_jid VARCHAR(255),
portal_receiver VARCHAR(255),
PRIMARY KEY (user_jid, portal_jid, portal_receiver),
FOREIGN KEY (user_jid) REFERENCES "user"(jid) ON DELETE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
)`)
return err
}}
}

View File

@ -1,19 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[7] = upgrade{"Add columns to store avatar MXC URIs", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN avatar_url VARCHAR(255)`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE portal ADD COLUMN avatar_url VARCHAR(255)`)
if err != nil {
return err
}
return nil
}}
}

View File

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[8] = upgrade{"Add columns to store portal in filtering community meta", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE user_portal ADD COLUMN in_community BOOLEAN NOT NULL DEFAULT FALSE`)
return err
}}
}

View File

@ -1,39 +0,0 @@
package upgrades
import (
"database/sql"
"strings"
)
func init() {
userProfileTable := `CREATE TABLE mx_user_profile (
room_id VARCHAR(255),
user_id VARCHAR(255),
membership VARCHAR(15) NOT NULL,
PRIMARY KEY (room_id, user_id)
)`
roomStateTable := `CREATE TABLE mx_room_state (
room_id VARCHAR(255) PRIMARY KEY,
power_levels TEXT
)`
registrationsTable := `CREATE TABLE mx_registrations (
user_id VARCHAR(255) PRIMARY KEY
)`
upgrades[9] = upgrade{"Move state store to main DB", func(tx *sql.Tx, ctx context) error {
if ctx.dialect == Postgres {
roomStateTable = strings.Replace(roomStateTable, "TEXT", "JSONB", 1)
}
if _, err := tx.Exec(userProfileTable); err != nil {
return err
} else if _, err = tx.Exec(roomStateTable); err != nil {
return err
} else if _, err = tx.Exec(registrationsTable); err != nil {
return err
}
return nil
}}
}

View File

@ -1,16 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[10] = upgrade{"Add columns to store full member info in state store", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE mx_user_profile ADD COLUMN displayname TEXT`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE mx_user_profile ADD COLUMN avatar_url VARCHAR(255)`)
return err
}}
}

View File

@ -1,16 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[11] = upgrade{"Adjust the length of column topic in portal", func(tx *sql.Tx, ctx context) error {
if ctx.dialect == SQLite {
// SQLite doesn't support constraint updates, but it isn't that careful with constraints anyway.
return nil
}
_, err := tx.Exec(`ALTER TABLE portal ALTER COLUMN topic TYPE VARCHAR(512)`)
return err
}}
}

View File

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[12] = upgrade{"Add encryption status to portal table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN encrypted BOOLEAN NOT NULL DEFAULT false`)
return err
}}
}

View File

@ -1,73 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[13] = upgrade{"Add crypto store to database", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`CREATE TABLE crypto_account (
device_id VARCHAR(255) PRIMARY KEY,
shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL,
account bytea NOT NULL
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_message_index (
sender_key CHAR(43),
session_id CHAR(43),
"index" INTEGER,
event_id VARCHAR(255) NOT NULL,
timestamp BIGINT NOT NULL,
PRIMARY KEY (sender_key, session_id, "index")
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_tracked_user (
user_id VARCHAR(255) PRIMARY KEY
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_device (
user_id VARCHAR(255),
device_id VARCHAR(255),
identity_key CHAR(43) NOT NULL,
signing_key CHAR(43) NOT NULL,
trust SMALLINT NOT NULL,
deleted BOOLEAN NOT NULL,
name VARCHAR(255) NOT NULL,
PRIMARY KEY (user_id, device_id)
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_olm_session (
session_id CHAR(43) PRIMARY KEY,
sender_key CHAR(43) NOT NULL,
session bytea NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_megolm_inbound_session (
session_id CHAR(43) PRIMARY KEY,
sender_key CHAR(43) NOT NULL,
signing_key CHAR(43) NOT NULL,
room_id VARCHAR(255) NOT NULL,
session bytea NOT NULL,
forwarding_chains bytea NOT NULL
)`)
if err != nil {
return err
}
return nil
}}
}

View File

@ -1,25 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[14] = upgrade{"Add outbound group sessions to database", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`CREATE TABLE crypto_megolm_outbound_session (
room_id VARCHAR(255) PRIMARY KEY,
session_id CHAR(43) NOT NULL UNIQUE,
session bytea NOT NULL,
shared BOOLEAN NOT NULL,
max_messages INTEGER NOT NULL,
message_count INTEGER NOT NULL,
max_age BIGINT NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL
)`)
if err != nil {
return err
}
return nil
}}
}

View File

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[15] = upgrade{"Add enable_presence column for puppets", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN enable_presence BOOLEAN NOT NULL DEFAULT true`)
return err
}}
}

View File

@ -1,13 +0,0 @@
package upgrades
import (
"database/sql"
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
)
func init() {
upgrades[16] = upgrade{"Add account_id to crypto store", func(tx *sql.Tx, c context) error {
return sql_store_upgrade.Upgrades[1](tx, c.dialect.String())
}}
}

View File

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[17] = upgrade{"Add enable_receipts column for puppets", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE puppet ADD COLUMN enable_receipts BOOLEAN NOT NULL DEFAULT true`)
return err
}}
}

View File

@ -1,13 +0,0 @@
package upgrades
import (
"database/sql"
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
)
func init() {
upgrades[18] = upgrade{"Add megolm withheld data to crypto store", func(tx *sql.Tx, c context) error {
return sql_store_upgrade.Upgrades[2](tx, c.dialect.String())
}}
}

View File

@ -1,13 +0,0 @@
package upgrades
import (
"database/sql"
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
)
func init() {
upgrades[19] = upgrade{"Add cross-signing keys to crypto store", func(tx *sql.Tx, c context) error {
return sql_store_upgrade.Upgrades[3](tx, c.dialect.String())
}}
}

View File

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[20] = upgrade{"Add sent column for messages", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE message ADD COLUMN sent BOOLEAN NOT NULL DEFAULT true`)
return err
}}
}

View File

@ -1,44 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[21] = upgrade{"Remove message content from local database", func(tx *sql.Tx, ctx context) error {
if ctx.dialect == SQLite {
_, err := tx.Exec("ALTER TABLE message RENAME TO old_message")
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS message (
chat_jid TEXT,
chat_receiver TEXT,
jid TEXT,
mxid TEXT NOT NULL UNIQUE,
sender TEXT NOT NULL,
timestamp BIGINT NOT NULL,
sent BOOLEAN NOT NULL,
PRIMARY KEY (chat_jid, chat_receiver, jid),
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
)`)
if err != nil {
return err
}
_, err = tx.Exec("INSERT INTO message SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, sent FROM old_message")
return err
} else {
_, err := tx.Exec(`ALTER TABLE message DROP COLUMN content`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE message ALTER COLUMN timestamp DROP DEFAULT`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE message ALTER COLUMN sent DROP DEFAULT`)
return err
}
}}
}

View File

@ -1,13 +0,0 @@
package upgrades
import (
"database/sql"
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
)
func init() {
upgrades[23] = upgrade{"Replace VARCHAR(255) with TEXT in the crypto database", func(tx *sql.Tx, ctx context) error {
return sql_store_upgrade.Upgrades[4](tx, ctx.dialect.String())
}}
}

View File

@ -1,48 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[22] = upgrade{"Replace VARCHAR(255) with TEXT in the database", func(tx *sql.Tx, ctx context) error {
if ctx.dialect == SQLite {
// SQLite doesn't enforce varchar sizes anyway
return nil
}
return execMany(tx,
`ALTER TABLE message ALTER COLUMN chat_jid TYPE TEXT`,
`ALTER TABLE message ALTER COLUMN chat_receiver TYPE TEXT`,
`ALTER TABLE message ALTER COLUMN jid TYPE TEXT`,
`ALTER TABLE message ALTER COLUMN mxid TYPE TEXT`,
`ALTER TABLE message ALTER COLUMN sender TYPE TEXT`,
`ALTER TABLE portal ALTER COLUMN jid TYPE TEXT`,
`ALTER TABLE portal ALTER COLUMN receiver TYPE TEXT`,
`ALTER TABLE portal ALTER COLUMN mxid TYPE TEXT`,
`ALTER TABLE portal ALTER COLUMN name TYPE TEXT`,
`ALTER TABLE portal ALTER COLUMN topic TYPE TEXT`,
`ALTER TABLE portal ALTER COLUMN avatar TYPE TEXT`,
`ALTER TABLE portal ALTER COLUMN avatar_url TYPE TEXT`,
`ALTER TABLE puppet ALTER COLUMN jid TYPE TEXT`,
`ALTER TABLE puppet ALTER COLUMN avatar TYPE TEXT`,
`ALTER TABLE puppet ALTER COLUMN displayname TYPE TEXT`,
`ALTER TABLE puppet ALTER COLUMN custom_mxid TYPE TEXT`,
`ALTER TABLE puppet ALTER COLUMN access_token TYPE TEXT`,
`ALTER TABLE puppet ALTER COLUMN next_batch TYPE TEXT`,
`ALTER TABLE puppet ALTER COLUMN avatar_url TYPE TEXT`,
`ALTER TABLE "user" ALTER COLUMN mxid TYPE TEXT`,
`ALTER TABLE "user" ALTER COLUMN jid TYPE TEXT`,
`ALTER TABLE "user" ALTER COLUMN management_room TYPE TEXT`,
`ALTER TABLE "user" ALTER COLUMN client_id TYPE TEXT`,
`ALTER TABLE "user" ALTER COLUMN client_token TYPE TEXT`,
`ALTER TABLE "user" ALTER COLUMN server_token TYPE TEXT`,
`ALTER TABLE user_portal ALTER COLUMN user_jid TYPE TEXT`,
`ALTER TABLE user_portal ALTER COLUMN portal_jid TYPE TEXT`,
`ALTER TABLE user_portal ALTER COLUMN portal_receiver TYPE TEXT`,
)
}}
}

View File

@ -1,13 +0,0 @@
package upgrades
import (
"database/sql"
"go.mau.fi/whatsmeow/store/sqlstore"
)
func init() {
upgrades[24] = upgrade{"Add whatsmeow state store", func(tx *sql.Tx, ctx context) error {
return sqlstore.Upgrades[0](tx, sqlstore.NewWithDB(ctx.db, ctx.dialect.String(), nil))
}}
}

View File

@ -1,93 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[25] = upgrade{"Update things for multidevice", func(tx *sql.Tx, ctx context) error {
// This is probably not necessary
_, err := tx.Exec("DROP TABLE user_portal")
if err != nil {
return err
}
// Remove invalid puppet rows
_, err = tx.Exec("DELETE FROM puppet WHERE jid LIKE '%@g.us' OR jid LIKE '%@broadcast'")
if err != nil {
return err
}
// Remove the suffix from puppets since they'll all have the same suffix
_, err = tx.Exec("UPDATE puppet SET jid=REPLACE(jid, '@s.whatsapp.net', '')")
if err != nil {
return err
}
// Rename column to correctly represent the new content
_, err = tx.Exec("ALTER TABLE puppet RENAME COLUMN jid TO username")
if err != nil {
return err
}
if ctx.dialect == SQLite {
// Message content was removed from the main message table earlier, but the backup table still exists for SQLite
_, err = tx.Exec("DROP TABLE IF EXISTS old_message")
_, err = tx.Exec(`ALTER TABLE "user" RENAME TO old_user`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE "user" (
mxid TEXT PRIMARY KEY,
username TEXT UNIQUE,
agent SMALLINT,
device SMALLINT,
management_room TEXT
)`)
if err != nil {
return err
}
// No need to copy auth data, users need to relogin anyway
_, err = tx.Exec(`INSERT INTO "user" (mxid, management_room) SELECT mxid, management_room FROM old_user`)
if err != nil {
return err
}
_, err = tx.Exec("DROP TABLE old_user")
if err != nil {
return err
}
} else {
// The jid column never actually contained the full JID, so let's rename it.
_, err = tx.Exec(`ALTER TABLE "user" RENAME COLUMN jid TO username`)
if err != nil {
return err
}
// The auth data is now in the whatsmeow_device table.
for _, column := range []string{"last_connection", "client_id", "client_token", "server_token", "enc_key", "mac_key"} {
_, err = tx.Exec(`ALTER TABLE "user" DROP COLUMN ` + column)
if err != nil {
return err
}
}
// The whatsmeow_device table is keyed by the full JID, so we need to store the other parts of the JID here too.
_, err = tx.Exec(`ALTER TABLE "user" ADD COLUMN agent SMALLINT`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE "user" ADD COLUMN device SMALLINT`)
if err != nil {
return err
}
// Clear all usernames, the users need to relogin anyway.
_, err = tx.Exec(`UPDATE "user" SET username=null`)
if err != nil {
return err
}
}
return nil
}}
}

View File

@ -1,19 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[26] = upgrade{"Add columns to store infinite backfill pointers for portals", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN first_event_id TEXT NOT NULL DEFAULT ''`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE portal ADD COLUMN next_batch_id TEXT NOT NULL DEFAULT ''`)
if err != nil {
return err
}
return nil
}}
}

View File

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[27] = upgrade{"Add marker for WhatsApp decryption errors in message table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE message ADD COLUMN decryption_error BOOLEAN NOT NULL DEFAULT false`)
return err
}}
}

View File

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[28] = upgrade{"Add relay user field to portal table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN relay_user_id TEXT`)
return err
}}
}

View File

@ -1,22 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[29] = upgrade{"Replace VARCHAR(255) with TEXT in the Matrix state store", func(tx *sql.Tx, ctx context) error {
if ctx.dialect == SQLite {
// SQLite doesn't enforce varchar sizes anyway
return nil
}
return execMany(tx,
`ALTER TABLE mx_registrations ALTER COLUMN user_id TYPE TEXT`,
`ALTER TABLE mx_room_state ALTER COLUMN room_id TYPE TEXT`,
`ALTER TABLE mx_user_profile ALTER COLUMN room_id TYPE TEXT`,
`ALTER TABLE mx_user_profile ALTER COLUMN user_id TYPE TEXT`,
`ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT`,
`ALTER TABLE mx_user_profile ALTER COLUMN avatar_url TYPE TEXT`,
)
}}
}

View File

@ -1,22 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[30] = upgrade{"Store last read message timestamp in database", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`CREATE TABLE user_portal (
user_mxid TEXT,
portal_jid TEXT,
portal_receiver TEXT,
last_read_ts BIGINT NOT NULL DEFAULT 0,
PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
)`)
return err
}}
}

View File

@ -1,13 +0,0 @@
package upgrades
import (
"database/sql"
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
)
func init() {
upgrades[31] = upgrade{"Split last_used into last_encrypted and last_decrypted in crypto store", func(tx *sql.Tx, c context) error {
return sql_store_upgrade.Upgrades[5](tx, c.dialect.String())
}}
}

View File

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[32] = upgrade{"Store source broadcast list in message table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE message ADD COLUMN broadcast_list_jid TEXT`)
return err
}}
}

View File

@ -1,16 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[33] = upgrade{"Add personal filtering space info to user tables", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN space_room TEXT NOT NULL DEFAULT ''`)
if err != nil {
return err
}
_, err = tx.Exec(`ALTER TABLE user_portal ADD COLUMN in_space BOOLEAN NOT NULL DEFAULT false`)
return err
}}
}

View File

@ -1,20 +0,0 @@
package upgrades
import "database/sql"
func init() {
upgrades[34] = upgrade{"Add support for disappearing messages", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN expiration_time BIGINT NOT NULL DEFAULT 0 CHECK (expiration_time >= 0 AND expiration_time < 4294967296)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE disappearing_message (
room_id TEXT,
event_id TEXT,
expire_in BIGINT NOT NULL,
expire_at BIGINT,
PRIMARY KEY (room_id, event_id)
)`)
return err
}}
}

View File

@ -1,10 +0,0 @@
package upgrades
import "database/sql"
func init() {
upgrades[35] = upgrade{"Store approximate last seen timestamp of the main device", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN phone_last_seen BIGINT`)
return err
}}
}

View File

@ -1,30 +0,0 @@
package upgrades
import "database/sql"
func init() {
upgrades[36] = upgrade{"Store message error type as string", func(tx *sql.Tx, ctx context) error {
if ctx.dialect == Postgres {
_, err := tx.Exec("CREATE TYPE error_type AS ENUM ('', 'decryption_failed', 'media_not_found')")
if err != nil {
return err
}
}
_, err := tx.Exec("ALTER TABLE message ADD COLUMN error error_type NOT NULL DEFAULT ''")
if err != nil {
return err
}
_, err = tx.Exec("UPDATE message SET error='decryption_failed' WHERE decryption_error=true")
if err != nil {
return err
}
if ctx.dialect == Postgres {
// TODO do this on sqlite at some point
_, err = tx.Exec("ALTER TABLE message DROP COLUMN decryption_error")
if err != nil {
return err
}
}
return nil
}}
}

View File

@ -1,10 +0,0 @@
package upgrades
import "database/sql"
func init() {
upgrades[37] = upgrade{"Store timestamp for previous phone ping", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN phone_last_pinged BIGINT`)
return err
}}
}

View File

@ -1,39 +0,0 @@
package upgrades
import "database/sql"
func init() {
upgrades[38] = upgrade{"Add support for reactions", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE message ADD COLUMN type TEXT NOT NULL DEFAULT 'message'`)
if err != nil {
return err
}
if ctx.dialect == Postgres {
_, err = tx.Exec("ALTER TABLE message ALTER COLUMN type DROP DEFAULT")
if err != nil {
return err
}
}
_, err = tx.Exec("UPDATE message SET type='' WHERE error='decryption_failed'")
if err != nil {
return err
}
_, err = tx.Exec("UPDATE message SET type='fake' WHERE jid LIKE 'FAKE::%' OR mxid LIKE 'net.maunium.whatsapp.fake::%' OR jid=mxid")
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE reaction (
chat_jid TEXT,
chat_receiver TEXT,
target_jid TEXT,
sender TEXT,
mxid TEXT NOT NULL,
jid TEXT NOT NULL,
PRIMARY KEY (chat_jid, chat_receiver, target_jid, sender),
CONSTRAINT target_message_fkey FOREIGN KEY (chat_jid, chat_receiver, target_jid)
REFERENCES message(chat_jid, chat_receiver, jid)
ON DELETE CASCADE ON UPDATE CASCADE
)`)
return err
}}
}

View File

@ -1,45 +0,0 @@
package upgrades
import (
"database/sql"
"fmt"
)
func init() {
upgrades[39] = upgrade{"Add backfill queue", func(tx *sql.Tx, ctx context) error {
// The queue_id needs to auto-increment every insertion. For SQLite,
// INTEGER PRIMARY KEY is an alias for the ROWID, so it will
// auto-increment. See https://sqlite.org/lang_createtable.html#rowid
// For Postgres, we need to add GENERATED ALWAYS AS IDENTITY for the
// same functionality.
queueIDColumnTypeModifier := ""
if ctx.dialect == Postgres {
queueIDColumnTypeModifier = "GENERATED ALWAYS AS IDENTITY"
}
_, err := tx.Exec(fmt.Sprintf(`
CREATE TABLE backfill_queue (
queue_id INTEGER PRIMARY KEY %s,
user_mxid TEXT,
type INTEGER NOT NULL,
priority INTEGER NOT NULL,
portal_jid TEXT,
portal_receiver TEXT,
time_start TIMESTAMP,
time_end TIMESTAMP,
max_batch_events INTEGER NOT NULL,
max_total_events INTEGER,
batch_delay INTEGER,
completed_at TIMESTAMP,
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
)
`, queueIDColumnTypeModifier))
if err != nil {
return err
}
return err
}}
}

View File

@ -1,52 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[40] = upgrade{"Store history syncs for later backfills", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`
CREATE TABLE history_sync_conversation (
user_mxid TEXT,
conversation_id TEXT,
portal_jid TEXT,
portal_receiver TEXT,
last_message_timestamp TIMESTAMP,
archived BOOLEAN,
pinned INTEGER,
mute_end_time TIMESTAMP,
disappearing_mode INTEGER,
end_of_history_transfer_type INTEGER,
ephemeral_expiration INTEGER,
marked_as_unread BOOLEAN,
unread_count INTEGER,
PRIMARY KEY (user_mxid, conversation_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
)
`)
if err != nil {
return err
}
_, err = tx.Exec(`
CREATE TABLE history_sync_message (
user_mxid TEXT,
conversation_id TEXT,
message_id TEXT,
timestamp TIMESTAMP,
data BYTEA,
PRIMARY KEY (user_mxid, conversation_id, message_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (user_mxid, conversation_id) REFERENCES history_sync_conversation(user_mxid, conversation_id) ON DELETE CASCADE
)
`)
if err != nil {
return err
}
return nil
}}
}

View File

@ -1,20 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[41] = upgrade{"Update backfill queue tables to be sortable by priority", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`
UPDATE backfill_queue
SET type=CASE
WHEN type=1 THEN 200
WHEN type=2 THEN 300
ELSE type
END
WHERE type=1 OR type=2
`)
return err
}}
}

View File

@ -1,26 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[42] = upgrade{"Add table of media to request from the user's phone", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`
CREATE TABLE media_backfill_requests (
user_mxid TEXT,
portal_jid TEXT,
portal_receiver TEXT,
event_id TEXT,
media_key BYTEA,
status INTEGER,
error TEXT,
PRIMARY KEY (user_mxid, portal_jid, portal_receiver, event_id),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
)
`)
return err
}}
}

View File

@ -1,12 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[43] = upgrade{"Add timezone column to user table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE "user" ADD COLUMN timezone TEXT`)
return err
}}
}

View File

@ -1,34 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[44] = upgrade{"Add dispatch time to backfill queue", func(tx *sql.Tx, ctx context) error {
// First, add dispatch_time TIMESTAMP column
_, err := tx.Exec(`
ALTER TABLE backfill_queue
ADD COLUMN dispatch_time TIMESTAMP
`)
if err != nil {
return err
}
// For all previous jobs, set dispatch time to the completed time.
_, err = tx.Exec(`
UPDATE backfill_queue
SET dispatch_time=completed_at
`)
if err != nil {
return err
}
// Remove time_end from the backfill queue
_, err = tx.Exec(`
ALTER TABLE backfill_queue
DROP COLUMN time_end
`)
return err
}}
}

View File

@ -1,16 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[45] = upgrade{"Add inserted time to history sync message", func(tx *sql.Tx, ctx context) error {
// Add the inserted time TIMESTAMP column to history_sync_message
_, err := tx.Exec(`
ALTER TABLE history_sync_message
ADD COLUMN inserted_time TIMESTAMP
`)
return err
}}
}

View File

@ -1,25 +0,0 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[46] = upgrade{"Create the backfill state table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`
CREATE TABLE backfill_state (
user_mxid TEXT,
portal_jid TEXT,
portal_receiver TEXT,
processing_batch BOOLEAN,
backfill_complete BOOLEAN,
first_expected_ts INTEGER,
PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
FOREIGN KEY (user_mxid) REFERENCES "user"(mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
)
`)
return err
}}
}

View File

@ -0,0 +1,5 @@
-- v45: Add dispatch time to backfill queue
ALTER TABLE backfill_queue ADD COLUMN dispatch_time TIMESTAMP;
UPDATE backfill_queue SET dispatch_time=completed_at;
ALTER TABLE backfill_queue DROP COLUMN time_end;

View File

@ -0,0 +1,3 @@
-- v46: Add inserted time to history sync message
ALTER TABLE history_sync_message ADD COLUMN inserted_time TIMESTAMP;

View File

@ -0,0 +1,13 @@
-- v47: Add table for keeping track of backfill state
CREATE TABLE backfill_state (
user_mxid TEXT,
portal_jid TEXT,
portal_receiver TEXT,
processing_batch BOOLEAN,
backfill_complete BOOLEAN,
first_expected_ts TIMESTAMP,
PRIMARY KEY (user_mxid, portal_jid, portal_receiver),
FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (portal_jid, portal_receiver) REFERENCES portal (jid, receiver) ON DELETE CASCADE
);

View File

@ -0,0 +1,7 @@
-- v48: Move crypto/state/whatsmeow store upgrade handling to separate systems
CREATE TABLE crypto_version (version INTEGER PRIMARY KEY);
INSERT INTO crypto_version VALUES (6);
CREATE TABLE whatsmeow_version (version INTEGER PRIMARY KEY);
INSERT INTO whatsmeow_version VALUES (1);
CREATE TABLE mx_version (version INTEGER PRIMARY KEY);
INSERT INTO mx_version VALUES (1);

View File

@ -1,180 +1,27 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package upgrades
import (
"database/sql"
"embed"
"errors"
"fmt"
"strings"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/util/dbutil"
)
type Dialect int
var Table dbutil.UpgradeTable
const (
Postgres Dialect = iota
SQLite
)
//go:embed *.sql
var rawUpgrades embed.FS
func (dialect Dialect) String() string {
switch dialect {
case Postgres:
return "postgres"
case SQLite:
return "sqlite3"
default:
return ""
}
}
type upgradeFunc func(*sql.Tx, context) error
type context struct {
dialect Dialect
db *sql.DB
log log.Logger
}
type upgrade struct {
message string
fn upgradeFunc
}
const NumberOfUpgrades = 47
var upgrades [NumberOfUpgrades]upgrade
var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
var ErrForeignTables = fmt.Errorf("the database contains foreign tables")
var ErrNotOwned = fmt.Errorf("the database is owned by")
var IgnoreForeignTables = false
const databaseOwner = "mautrix-whatsapp"
func GetVersion(db *sql.DB) (int, error) {
_, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)")
if err != nil {
return -1, err
}
version := 0
err = db.QueryRow("SELECT version FROM version LIMIT 1").Scan(&version)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return -1, err
}
return version, nil
}
const tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
const tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND table_name=$1)"
func tableExists(dialect Dialect, db *sql.DB, table string) (exists bool) {
if dialect == SQLite {
_ = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
} else if dialect == Postgres {
_ = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
}
return
}
const createOwnerTable = `
CREATE TABLE IF NOT EXISTS database_owner (
key INTEGER PRIMARY KEY DEFAULT 0,
owner TEXT NOT NULL
)
`
func CheckDatabaseOwner(dialect Dialect, db *sql.DB) error {
var owner string
if !IgnoreForeignTables {
if tableExists(dialect, db, "state_groups_state") {
return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
} else if tableExists(dialect, db, "goose_db_version") {
return fmt.Errorf("%w (found goose_db_version, possibly belonging to Dendrite)", ErrForeignTables)
}
}
if _, err := db.Exec(createOwnerTable); err != nil {
return fmt.Errorf("failed to ensure database owner table exists: %w", err)
} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
_, err = db.Exec("INSERT INTO database_owner (owner) VALUES ($1)", databaseOwner)
if err != nil {
return fmt.Errorf("failed to insert database owner: %w", err)
}
} else if err != nil {
return fmt.Errorf("failed to check database owner: %w", err)
} else if owner != databaseOwner {
return fmt.Errorf("%w %s", ErrNotOwned, owner)
}
return nil
}
func SetVersion(tx *sql.Tx, version int) error {
_, err := tx.Exec("DELETE FROM version")
if err != nil {
return err
}
_, err = tx.Exec("INSERT INTO version (version) VALUES ($1)", version)
return err
}
func execMany(tx *sql.Tx, queries ...string) error {
for _, query := range queries {
_, err := tx.Exec(query)
if err != nil {
return err
}
}
return nil
}
func Run(log log.Logger, dialectName string, db *sql.DB) error {
var dialect Dialect
switch strings.ToLower(dialectName) {
case "postgres":
dialect = Postgres
case "sqlite3":
dialect = SQLite
default:
return fmt.Errorf("unknown dialect %s", dialectName)
}
err := CheckDatabaseOwner(dialect, db)
if err != nil {
return err
}
version, err := GetVersion(db)
if err != nil {
return err
}
if version > NumberOfUpgrades {
return fmt.Errorf("%w: currently on v%d, latest known: v%d", ErrUnsupportedDatabaseVersion, version, NumberOfUpgrades)
}
log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades)
for i, upgradeItem := range upgrades[version:] {
if upgradeItem.fn == nil {
continue
}
log.Infofln("Upgrading database to v%d: %s", version+i+1, upgradeItem.message)
var tx *sql.Tx
tx, err = db.Begin()
if err != nil {
return err
}
err = upgradeItem.fn(tx, context{dialect, db, log})
if err != nil {
return err
}
err = SetVersion(tx, version+i+1)
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
}
}
return nil
func init() {
Table.Register(-1, 43, "Unsupported version", func(tx *sql.Tx, database *dbutil.Database) error {
return errors.New("please upgrade to mautrix-whatsapp v0.4.0 before upgrading to a newer version")
})
Table.RegisterFS(rawUpgrades)
}

View File

@ -24,6 +24,7 @@ import (
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
"go.mau.fi/whatsmeow/types"
)
@ -89,7 +90,7 @@ type User struct {
inSpaceCacheLock sync.Mutex
}
func (user *User) Scan(row Scannable) *User {
func (user *User) Scan(row dbutil.Scannable) *User {
var username, timezone sql.NullString
var device, agent sql.NullByte
var phoneLastSeen, phoneLastPinged sql.NullInt64

View File

@ -50,9 +50,9 @@ func (portal *Portal) ScheduleDisappearing() {
}
}
func (bridge *Bridge) SleepAndDeleteUpcoming() {
for _, msg := range bridge.DB.DisappearingMessage.GetUpcomingScheduled(1 * time.Hour) {
portal := bridge.GetPortalByMXID(msg.RoomID)
func (br *WABridge) SleepAndDeleteUpcoming() {
for _, msg := range br.DB.DisappearingMessage.GetUpcomingScheduled(1 * time.Hour) {
portal := br.GetPortalByMXID(msg.RoomID)
if portal == nil {
msg.Delete()
} else {

View File

@ -43,14 +43,6 @@ appservice:
max_conn_idle_time: null
max_conn_lifetime: null
# Settings for provisioning API
provisioning:
# Prefix for the provisioning API paths.
prefix: /_matrix/provision
# Shared secret for authentication. If set to "generate", a random secret will be generated,
# or if set to "disable", the provisioning API will be disabled.
shared_secret: generate
# The unique ID of this appservice.
id: whatsapp
# Appservice bot details.
@ -317,6 +309,14 @@ bridge:
# Verification by the bridge is not yet implemented.
require_verification: true
# Settings for provisioning API
provisioning:
# Prefix for the provisioning API paths.
prefix: /_matrix/provision
# Shared secret for authentication. If set to "generate", a random secret will be generated,
# or if set to "disable", the provisioning API will be disabled.
shared_secret: generate
# Permissions for using the bridge.
# Permitted values:
# relay - Talk through the relaybot (if enabled), no access otherwise

View File

@ -37,7 +37,7 @@ var codeBlockRegex = regexp.MustCompile("```(?:.|\n)+?```")
const mentionedJIDsContextKey = "net.maunium.whatsapp.mentioned_jids"
type Formatter struct {
bridge *Bridge
bridge *WABridge
matrixHTMLParser *format.HTMLParser
@ -46,7 +46,7 @@ type Formatter struct {
waReplFuncText map[*regexp.Regexp]func(string) string
}
func NewFormatter(bridge *Bridge) *Formatter {
func NewFormatter(bridge *WABridge) *Formatter {
formatter := &Formatter{
bridge: bridge,
matrixHTMLParser: &format.HTMLParser{

7
go.mod
View File

@ -14,10 +14,8 @@ require (
golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9
golang.org/x/net v0.0.0-20220513224357-95641704303c
google.golang.org/protobuf v1.28.0
gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99
maunium.net/go/mauflag v1.0.0
maunium.net/go/maulogger/v2 v2.3.2
maunium.net/go/mautrix v0.11.1-0.20220518174602-87d2cd49a4d1
maunium.net/go/mautrix v0.11.1-0.20220521215033-d578d1a610d5
)
require (
@ -37,7 +35,8 @@ require (
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 // indirect
golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886 // indirect
golang.org/x/text v0.3.7 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.0 // indirect
maunium.net/go/mauflag v1.0.0 // indirect
)
// Exclude some things that cause go.sum to explode

9
go.sum
View File

@ -99,14 +99,13 @@ google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99 h1:dbuHpmKjkDzSOMKAWl10QNlgaZUd3V1q99xc81tt2Kc=
gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0 h1:hjy8E9ON/egN1tAYqKb61G10WtihqetD4sz2H+8nIeA=
gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0=
maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A=
maunium.net/go/mautrix v0.11.1-0.20220518174602-87d2cd49a4d1 h1:+KEF+nSuBfHWsfQRz92YP/DdSLbComLoXCXgcrH6WRU=
maunium.net/go/mautrix v0.11.1-0.20220518174602-87d2cd49a4d1/go.mod h1:K29EcHwsNg6r7fMfwvi0GHQ9o5wSjqB9+Q8RjCIQEjA=
maunium.net/go/mautrix v0.11.1-0.20220521215033-d578d1a610d5 h1:7ZORg2h+lflc1HwjTKCXZnykauXD+wzbW+VDknbv6SU=
maunium.net/go/mautrix v0.11.1-0.20220521215033-d578d1a610d5/go.mod h1:oma8o6Y/5jcViBlDbX7tp1ajP2XP+b78h8twdI+zKI0=

515
main.go
View File

@ -18,43 +18,26 @@ package main
import (
_ "embed"
"errors"
"fmt"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"google.golang.org/protobuf/proto"
"go.mau.fi/whatsmeow"
waProto "go.mau.fi/whatsmeow/binary/proto"
"go.mau.fi/whatsmeow/store"
"go.mau.fi/whatsmeow/store/sqlstore"
"go.mau.fi/whatsmeow/types"
"google.golang.org/protobuf/proto"
flag "maunium.net/go/mauflag"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/configupgrade"
"maunium.net/go/mautrix-whatsapp/config"
"maunium.net/go/mautrix-whatsapp/database"
"maunium.net/go/mautrix-whatsapp/database/upgrades"
)
// The name and repo URL of the bridge.
var (
Name = "mautrix-whatsapp"
URL = "https://github.com/mautrix/whatsapp"
)
// Information to find out exactly which commit the bridge was built from.
@ -65,120 +48,19 @@ var (
BuildTime = "unknown"
)
var (
// Version is the version number of the bridge. Changed manually when making a release.
Version = "0.4.0"
// WAVersion is the version number exposed to WhatsApp. Filled in init()
WAVersion = ""
// VersionString is the bridge version, plus commit information. Filled in init() using the build-time values.
VersionString = ""
)
//go:embed example-config.yaml
var ExampleConfig string
func init() {
if len(Tag) > 0 && Tag[0] == 'v' {
Tag = Tag[1:]
}
if Tag != Version {
suffix := ""
if !strings.HasSuffix(Version, "+dev") {
suffix = "+dev"
}
if len(Commit) > 8 {
Version = fmt.Sprintf("%s%s.%s", Version, suffix, Commit[:8])
} else {
Version = fmt.Sprintf("%s%s.unknown", Version, suffix)
}
}
mautrix.DefaultUserAgent = fmt.Sprintf("mautrix-whatsapp/%s %s", Version, mautrix.DefaultUserAgent)
WAVersion = strings.FieldsFunc(Version, func(r rune) bool { return r == '-' || r == '+' })[0]
VersionString = fmt.Sprintf("%s %s (%s)", Name, Version, BuildTime)
config.ExampleConfig = ExampleConfig
}
var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String()
var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config to disk.", "false").Bool()
var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String()
var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool()
var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool()
var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool()
var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool()
var migrateFrom = flag.Make().LongKey("migrate-db").Usage("Source database type and URI to migrate from.").Bool()
var wantHelp, _ = flag.MakeHelpFlag()
func (bridge *Bridge) GenerateRegistration() {
if *dontSaveConfig {
// We need to save the generated as_token and hs_token in the config
_, _ = fmt.Fprintln(os.Stderr, "--no-update is not compatible with --generate-registration")
os.Exit(5)
}
reg, err := bridge.Config.NewRegistration()
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to generate registration:", err)
os.Exit(20)
}
err = reg.Save(*registrationPath)
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to save registration:", err)
os.Exit(21)
}
err = config.Mutate(*configPath, func(helper *configupgrade.Helper) {
helper.Set(configupgrade.Str, bridge.Config.AppService.ASToken, "appservice", "as_token")
helper.Set(configupgrade.Str, bridge.Config.AppService.HSToken, "appservice", "hs_token")
})
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to save config:", err)
os.Exit(22)
}
fmt.Println("Registration generated. Add the path to the registration to your Synapse config, restart it, then start the bridge.")
os.Exit(0)
}
func (bridge *Bridge) MigrateDatabase() {
oldDB, err := database.New(config.DatabaseConfig{Type: flag.Arg(0), URI: flag.Arg(1)}, log.DefaultLogger)
if err != nil {
fmt.Println("Failed to open old database:", err)
os.Exit(30)
}
err = oldDB.Init()
if err != nil {
fmt.Println("Failed to upgrade old database:", err)
os.Exit(31)
}
newDB, err := database.New(bridge.Config.AppService.Database, log.DefaultLogger)
if err != nil {
fmt.Println("Failed to open new database:", err)
os.Exit(32)
}
err = newDB.Init()
if err != nil {
fmt.Println("Failed to upgrade new database:", err)
os.Exit(33)
}
database.Migrate(oldDB, newDB)
}
type Bridge struct {
AS *appservice.AppService
EventProcessor *appservice.EventProcessor
MatrixHandler *MatrixHandler
Config *config.Config
DB *database.Database
Log log.Logger
StateStore *database.SQLStateStore
Provisioning *ProvisioningAPI
Bot *appservice.IntentAPI
Formatter *Formatter
Crypto Crypto
Metrics *MetricsHandler
WAContainer *sqlstore.Container
type WABridge struct {
bridge.Bridge
MatrixHandler *MatrixHandler
Config *config.Config
DB *database.Database
Provisioning *ProvisioningAPI
Formatter *Formatter
Metrics *MetricsHandler
WAContainer *sqlstore.Container
WAVersion string
usersByMXID map[id.UserID]*User
usersByUsername map[string]*User
@ -195,111 +77,32 @@ type Bridge struct {
puppetsLock sync.Mutex
}
type Crypto interface {
HandleMemberEvent(*event.Event)
Decrypt(*event.Event) (*event.Event, error)
Encrypt(id.RoomID, event.Type, event.Content) (*event.EncryptedEventContent, error)
WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
ResetSession(id.RoomID)
Init() error
Start()
Stop()
}
func (bridge *Bridge) ensureConnection() {
for {
versions, err := bridge.Bot.Versions()
if err != nil {
bridge.Log.Errorfln("Failed to connect to homeserver: %v. Retrying in 10 seconds...", err)
time.Sleep(10 * time.Second)
continue
}
if !versions.ContainsGreaterOrEqual(mautrix.SpecV11) {
bridge.Log.Warnfln("Server isn't advertising modern spec versions")
}
resp, err := bridge.Bot.Whoami()
if err != nil {
if errors.Is(err, mautrix.MUnknownToken) {
bridge.Log.Fatalln("The as_token was not accepted. Is the registration file installed in your homeserver correctly?")
os.Exit(16)
} else if errors.Is(err, mautrix.MExclusive) {
bridge.Log.Fatalln("The as_token was accepted, but the /register request was not. Are the homeserver domain and username template in the config correct, and do they match the values in the registration?")
os.Exit(16)
}
bridge.Log.Errorfln("Failed to connect to homeserver: %v. Retrying in 10 seconds...", err)
time.Sleep(10 * time.Second)
} else if resp.UserID != bridge.Bot.UserID {
bridge.Log.Fatalln("Unexpected user ID in whoami call: got %s, expected %s", resp.UserID, bridge.Bot.UserID)
os.Exit(17)
} else {
break
}
}
}
func (bridge *Bridge) Init() {
var err error
bridge.AS, err = bridge.Config.MakeAppService()
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to initialize AppService:", err)
os.Exit(11)
}
_, _ = bridge.AS.Init()
bridge.Log = log.Create()
bridge.Config.Logging.Configure(bridge.Log)
log.DefaultLogger = bridge.Log.(*log.BasicLogger)
if len(bridge.Config.Logging.FileNameFormat) > 0 {
err = log.OpenFile()
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to open log file:", err)
os.Exit(12)
}
}
bridge.AS.Log = log.Sub("Matrix")
bridge.Bot = bridge.AS.BotIntent()
bridge.Log.Infoln("Initializing", VersionString)
bridge.Log.Debugln("Initializing database connection")
bridge.DB, err = database.New(bridge.Config.AppService.Database, bridge.Log)
if err != nil {
bridge.Log.Fatalln("Failed to initialize database connection:", err)
os.Exit(14)
}
bridge.Log.Debugln("Initializing state store")
bridge.StateStore = database.NewSQLStateStore(bridge.DB)
bridge.AS.StateStore = bridge.StateStore
Segment.log = bridge.Log.Sub("Segment")
Segment.key = bridge.Config.SegmentKey
func (br *WABridge) Init() {
Segment.log = br.Log.Sub("Segment")
Segment.key = br.Config.SegmentKey
if Segment.IsEnabled() {
Segment.log.Infoln("Segment metrics are enabled")
}
bridge.WAContainer = sqlstore.NewWithDB(bridge.DB.DB, bridge.Config.AppService.Database.Type, nil)
bridge.WAContainer.DatabaseErrorHandler = bridge.DB.HandleSignalStoreError
br.DB = database.New(br.Bridge.DB)
br.WAContainer = sqlstore.NewWithDB(br.DB.DB, br.DB.Dialect.String(), nil)
br.WAContainer.DatabaseErrorHandler = br.DB.HandleSignalStoreError
ss := bridge.Config.AppService.Provisioning.SharedSecret
ss := br.Config.Bridge.Provisioning.SharedSecret
if len(ss) > 0 && ss != "disable" {
bridge.Provisioning = &ProvisioningAPI{bridge: bridge}
br.Provisioning = &ProvisioningAPI{bridge: br}
}
bridge.Log.Debugln("Initializing Matrix event processor")
bridge.EventProcessor = appservice.NewEventProcessor(bridge.AS)
bridge.Log.Debugln("Initializing Matrix event handler")
bridge.MatrixHandler = NewMatrixHandler(bridge)
bridge.Formatter = NewFormatter(bridge)
bridge.Crypto = NewCryptoHelper(bridge)
bridge.Metrics = NewMetricsHandler(bridge.Config.Metrics.Listen, bridge.Log.Sub("Metrics"), bridge.DB)
br.Log.Debugln("Initializing Matrix event handler")
br.MatrixHandler = NewMatrixHandler(br)
br.Formatter = NewFormatter(br)
br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.Log.Sub("Metrics"), br.DB)
store.BaseClientPayload.UserAgent.OsVersion = proto.String(WAVersion)
store.BaseClientPayload.UserAgent.OsBuildNumber = proto.String(WAVersion)
store.CompanionProps.Os = proto.String(bridge.Config.WhatsApp.OSName)
store.CompanionProps.RequireFullSync = proto.Bool(bridge.Config.Bridge.HistorySync.RequestFullSync)
versionParts := strings.Split(WAVersion, ".")
store.BaseClientPayload.UserAgent.OsVersion = proto.String(br.WAVersion)
store.BaseClientPayload.UserAgent.OsBuildNumber = proto.String(br.WAVersion)
store.CompanionProps.Os = proto.String(br.Config.WhatsApp.OSName)
store.CompanionProps.RequireFullSync = proto.Bool(br.Config.Bridge.HistorySync.RequestFullSync)
versionParts := strings.Split(br.WAVersion, ".")
if len(versionParts) > 2 {
primary, _ := strconv.Atoi(versionParts[0])
secondary, _ := strconv.Atoi(versionParts[1])
@ -308,161 +111,107 @@ func (bridge *Bridge) Init() {
store.CompanionProps.Version.Secondary = proto.Uint32(uint32(secondary))
store.CompanionProps.Version.Tertiary = proto.Uint32(uint32(tertiary))
}
platformID, ok := waProto.CompanionProps_CompanionPropsPlatformType_value[strings.ToUpper(bridge.Config.WhatsApp.BrowserName)]
platformID, ok := waProto.CompanionProps_CompanionPropsPlatformType_value[strings.ToUpper(br.Config.WhatsApp.BrowserName)]
if ok {
store.CompanionProps.PlatformType = waProto.CompanionProps_CompanionPropsPlatformType(platformID).Enum()
}
}
func (bridge *Bridge) Start() {
bridge.Log.Debugln("Running database upgrades")
err := bridge.DB.Init()
if err != nil && (!errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) || !*ignoreUnsupportedDatabase) {
bridge.Log.Fatalln("Failed to initialize database:", err)
if errors.Is(err, upgrades.ErrForeignTables) {
bridge.Log.Infoln("You can use --ignore-foreign-tables to ignore this error")
} else if errors.Is(err, upgrades.ErrNotOwned) {
bridge.Log.Infoln("Sharing the same database with different programs is not supported")
} else if errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) {
bridge.Log.Infoln("Downgrading the bridge is not supported")
}
func (br *WABridge) Start() {
err := br.WAContainer.Upgrade()
if err != nil {
br.Log.Fatalln("Failed to upgrade whatsmeow database: %v", err)
os.Exit(15)
}
bridge.Log.Debugln("Checking connection to homeserver")
bridge.ensureConnection()
if bridge.Crypto != nil {
err = bridge.Crypto.Init()
if err != nil {
bridge.Log.Fatalln("Error initializing end-to-bridge encryption:", err)
os.Exit(19)
}
if br.Provisioning != nil {
br.Log.Debugln("Initializing provisioning API")
br.Provisioning.Init()
}
if bridge.Provisioning != nil {
bridge.Log.Debugln("Initializing provisioning API")
bridge.Provisioning.Init()
}
bridge.Log.Debugln("Starting application service HTTP server")
go bridge.AS.Start()
bridge.Log.Debugln("Starting event processor")
go bridge.EventProcessor.Start()
go bridge.CheckWhatsAppUpdate()
go bridge.UpdateBotProfile()
if bridge.Crypto != nil {
go bridge.Crypto.Start()
}
go bridge.StartUsers()
if bridge.Config.Metrics.Enabled {
go bridge.Metrics.Start()
go br.CheckWhatsAppUpdate()
go br.StartUsers()
if br.Config.Metrics.Enabled {
go br.Metrics.Start()
}
if bridge.Config.Bridge.ResendBridgeInfo {
go bridge.ResendBridgeInfo()
if br.Config.Bridge.ResendBridgeInfo {
go br.ResendBridgeInfo()
}
go bridge.Loop()
bridge.AS.Ready = true
go br.Loop()
}
func (bridge *Bridge) CheckWhatsAppUpdate() {
bridge.Log.Debugfln("Checking for WhatsApp web update")
func (br *WABridge) CheckWhatsAppUpdate() {
br.Log.Debugfln("Checking for WhatsApp web update")
resp, err := whatsmeow.CheckUpdate(http.DefaultClient)
if err != nil {
bridge.Log.Warnfln("Failed to check for WhatsApp web update: %v", err)
br.Log.Warnfln("Failed to check for WhatsApp web update: %v", err)
return
}
if store.GetWAVersion() == resp.ParsedVersion {
bridge.Log.Debugfln("Bridge is using latest WhatsApp web protocol")
br.Log.Debugfln("Bridge is using latest WhatsApp web protocol")
} else if store.GetWAVersion().LessThan(resp.ParsedVersion) {
if resp.IsBelowHard || resp.IsBroken {
bridge.Log.Warnfln("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
br.Log.Warnfln("Bridge is using outdated WhatsApp web protocol and probably doesn't work anymore (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
} else if resp.IsBelowSoft {
bridge.Log.Infofln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
br.Log.Infofln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
} else {
bridge.Log.Debugfln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
br.Log.Debugfln("Bridge is using outdated WhatsApp web protocol (%s, latest is %s)", store.GetWAVersion(), resp.ParsedVersion)
}
} else {
bridge.Log.Debugfln("Bridge is using newer than latest WhatsApp web protocol")
br.Log.Debugfln("Bridge is using newer than latest WhatsApp web protocol")
}
}
func (bridge *Bridge) Loop() {
func (br *WABridge) Loop() {
for {
bridge.SleepAndDeleteUpcoming()
br.SleepAndDeleteUpcoming()
time.Sleep(1 * time.Hour)
bridge.WarnUsersAboutDisconnection()
br.WarnUsersAboutDisconnection()
}
}
func (bridge *Bridge) WarnUsersAboutDisconnection() {
bridge.usersLock.Lock()
for _, user := range bridge.usersByUsername {
func (br *WABridge) WarnUsersAboutDisconnection() {
br.usersLock.Lock()
for _, user := range br.usersByUsername {
if user.IsConnected() && !user.PhoneRecentlySeen(true) {
go user.sendPhoneOfflineWarning()
}
}
bridge.usersLock.Unlock()
br.usersLock.Unlock()
}
func (bridge *Bridge) ResendBridgeInfo() {
if *dontSaveConfig {
bridge.Log.Warnln("Not setting resend_bridge_info to false in config due to --no-update flag")
} else {
err := config.Mutate(*configPath, func(helper *configupgrade.Helper) {
helper.Set(configupgrade.Bool, "false", "bridge", "resend_bridge_info")
})
if err != nil {
bridge.Log.Errorln("Failed to save config after setting resend_bridge_info to false:", err)
}
}
bridge.Log.Infoln("Re-sending bridge info state event to all portals")
for _, portal := range bridge.GetAllPortals() {
portal.UpdateBridgeInfo()
}
bridge.Log.Infoln("Finished re-sending bridge info state events")
func (br *WABridge) ResendBridgeInfo() {
// FIXME
//if *dontSaveConfig {
// br.Log.Warnln("Not setting resend_bridge_info to false in config due to --no-update flag")
//} else {
// err := config.Mutate(*configPath, func(helper *configupgrade.Helper) {
// helper.Set(configupgrade.Bool, "false", "bridge", "resend_bridge_info")
// })
// if err != nil {
// br.Log.Errorln("Failed to save config after setting resend_bridge_info to false:", err)
// }
//}
//br.Log.Infoln("Re-sending bridge info state event to all portals")
//for _, portal := range br.GetAllPortals() {
// portal.UpdateBridgeInfo()
//}
//br.Log.Infoln("Finished re-sending bridge info state events")
}
func (bridge *Bridge) UpdateBotProfile() {
bridge.Log.Debugln("Updating bot profile")
botConfig := &bridge.Config.AppService.Bot
var err error
var mxc id.ContentURI
if botConfig.Avatar == "remove" {
err = bridge.Bot.SetAvatarURL(mxc)
} else if len(botConfig.Avatar) > 0 {
mxc, err = id.ParseContentURI(botConfig.Avatar)
if err == nil {
err = bridge.Bot.SetAvatarURL(mxc)
}
botConfig.ParsedAvatar = mxc
}
if err != nil {
bridge.Log.Warnln("Failed to update bot avatar:", err)
}
if botConfig.Displayname == "remove" {
err = bridge.Bot.SetDisplayName("")
} else if len(botConfig.Displayname) > 0 {
err = bridge.Bot.SetDisplayName(botConfig.Displayname)
}
if err != nil {
bridge.Log.Warnln("Failed to update bot displayname:", err)
}
}
func (bridge *Bridge) StartUsers() {
bridge.Log.Debugln("Starting users")
func (br *WABridge) StartUsers() {
br.Log.Debugln("Starting users")
foundAnySessions := false
for _, user := range bridge.GetAllUsers() {
for _, user := range br.GetAllUsers() {
if !user.JID.IsEmpty() {
foundAnySessions = true
}
go user.Connect()
}
if !foundAnySessions {
bridge.sendGlobalBridgeState(BridgeState{StateEvent: StateUnconfigured}.fill(nil))
br.sendGlobalBridgeState(BridgeState{StateEvent: StateUnconfigured}.fill(nil))
}
bridge.Log.Debugln("Starting custom puppets")
for _, loopuppet := range bridge.GetAllPuppetsWithCustomMXID() {
br.Log.Debugln("Starting custom puppets")
for _, loopuppet := range br.GetAllPuppetsWithCustomMXID() {
go func(puppet *Puppet) {
puppet.log.Debugln("Starting custom puppet", puppet.CustomMXID)
err := puppet.StartCustomMXID(true)
@ -473,80 +222,37 @@ func (bridge *Bridge) StartUsers() {
}
}
func (bridge *Bridge) Stop() {
if bridge.Crypto != nil {
bridge.Crypto.Stop()
func (br *WABridge) Stop() {
if br.Crypto != nil {
br.Crypto.Stop()
}
bridge.AS.Stop()
bridge.Metrics.Stop()
bridge.EventProcessor.Stop()
for _, user := range bridge.usersByUsername {
br.AS.Stop()
br.Metrics.Stop()
br.EventProcessor.Stop()
for _, user := range br.usersByUsername {
if user.Client == nil {
continue
}
bridge.Log.Debugln("Disconnecting", user.MXID)
br.Log.Debugln("Disconnecting", user.MXID)
user.Client.Disconnect()
close(user.historySyncs)
}
}
func (bridge *Bridge) Main() {
configData, upgraded, err := config.Upgrade(*configPath, !*dontSaveConfig)
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Error updating config:", err)
if configData == nil {
os.Exit(10)
}
func (br *WABridge) GetExampleConfig() string {
return ExampleConfig
}
func (br *WABridge) GetConfigPtr() interface{} {
br.Config = &config.Config{
BaseConfig: &br.Bridge.Config,
}
bridge.Config, err = config.Load(configData, upgraded)
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, "Failed to parse config:", err)
os.Exit(10)
}
if *generateRegistration {
bridge.GenerateRegistration()
return
} else if *migrateFrom {
bridge.MigrateDatabase()
return
}
bridge.Init()
bridge.Log.Infoln("Bridge initialization complete, starting...")
bridge.Start()
bridge.Log.Infoln("Bridge started!")
c := make(chan os.Signal)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
<-c
bridge.Log.Infoln("Interrupt received, stopping...")
bridge.Stop()
bridge.Log.Infoln("Bridge stopped.")
os.Exit(0)
br.Config.BaseConfig.Bridge = &br.Config.Bridge
return br.Config
}
func main() {
flag.SetHelpTitles(
"mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.",
"mautrix-whatsapp [-h] [-c <path>] [-r <path>] [-g] [--migrate-db <source type> <source uri>]")
err := flag.Parse()
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, err)
flag.PrintHelp()
os.Exit(1)
} else if *wantHelp {
flag.PrintHelp()
os.Exit(0)
} else if *version {
fmt.Println(VersionString)
return
}
upgrades.IgnoreForeignTables = *ignoreForeignTables
(&Bridge{
br := &WABridge{
usersByMXID: make(map[id.UserID]*User),
usersByUsername: make(map[string]*User),
spaceRooms: make(map[id.RoomID]*User),
@ -555,5 +261,24 @@ func main() {
portalsByJID: make(map[database.PortalKey]*Portal),
puppets: make(map[types.JID]*Puppet),
puppetsByCustomMXID: make(map[id.UserID]*Puppet),
}).Main()
}
br.Bridge = bridge.Bridge{
Name: "mautrix-whatsapp",
URL: "https://github.com/mautrix/whatsapp",
Description: "A Matrix-WhatsApp puppeting bridge.",
Version: "0.4.0",
ProtocolName: "WhatsApp",
ConfigUpgrader: &configupgrade.StructUpgrader{
SimpleUpgrader: configupgrade.SimpleUpgrader(config.DoUpgrade),
Blocks: config.SpacedBlocks,
Base: ExampleConfig,
},
Child: br,
}
br.InitVersion(Tag, Commit, BuildTime)
br.WAVersion = strings.FieldsFunc(br.Version, func(r rune) bool { return r == '-' || r == '+' })[0]
br.Main()
}

View File

@ -28,6 +28,7 @@ import (
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
@ -36,13 +37,13 @@ import (
)
type MatrixHandler struct {
bridge *Bridge
bridge *WABridge
as *appservice.AppService
log maulogger.Logger
cmd *CommandHandler
}
func NewMatrixHandler(bridge *Bridge) *MatrixHandler {
func NewMatrixHandler(bridge *WABridge) *MatrixHandler {
handler := &MatrixHandler{
bridge: bridge,
as: bridge.AS,
@ -362,7 +363,7 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) {
decrypted, err := mx.bridge.Crypto.Decrypt(evt)
decryptionRetryCount := 0
if errors.Is(err, NoSessionFound) {
if errors.Is(err, bridge.NoSessionFound) {
content := evt.Content.AsEncrypted()
mx.log.Debugfln("Couldn't find session %s trying to decrypt %s, waiting %d seconds...", content.SessionID, evt.ID, int(sessionWaitTimeout.Seconds()))
mx.as.SendErrorMessageSendCheckpoint(evt, appservice.StepDecrypted, err, false, decryptionRetryCount)

View File

@ -1,17 +0,0 @@
//go:build !cgo || nocrypto
package main
import (
"errors"
)
func NewCryptoHelper(bridge *Bridge) Crypto {
if !bridge.Config.Bridge.Encryption.Allow {
bridge.Log.Warnln("Bridge built without end-to-bridge encryption, but encryption is enabled in config")
}
bridge.Log.Debugln("Bridge built without end-to-bridge encryption")
return nil
}
var NoSessionFound = errors.New("nil")

View File

@ -45,6 +45,7 @@ import (
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
@ -69,68 +70,72 @@ const PrivateChatTopic = "WhatsApp private chat"
var ErrStatusBroadcastDisabled = errors.New("status bridging is disabled")
func (bridge *Bridge) GetPortalByMXID(mxid id.RoomID) *Portal {
bridge.portalsLock.Lock()
defer bridge.portalsLock.Unlock()
portal, ok := bridge.portalsByMXID[mxid]
func (br *WABridge) GetPortalByMXID(mxid id.RoomID) *Portal {
br.portalsLock.Lock()
defer br.portalsLock.Unlock()
portal, ok := br.portalsByMXID[mxid]
if !ok {
return bridge.loadDBPortal(bridge.DB.Portal.GetByMXID(mxid), nil)
return br.loadDBPortal(br.DB.Portal.GetByMXID(mxid), nil)
}
return portal
}
func (bridge *Bridge) GetPortalByJID(key database.PortalKey) *Portal {
bridge.portalsLock.Lock()
defer bridge.portalsLock.Unlock()
portal, ok := bridge.portalsByJID[key]
func (br *WABridge) GetIPortalByMXID(mxid id.RoomID) bridge.Portal {
return br.GetPortalByMXID(mxid)
}
func (br *WABridge) GetPortalByJID(key database.PortalKey) *Portal {
br.portalsLock.Lock()
defer br.portalsLock.Unlock()
portal, ok := br.portalsByJID[key]
if !ok {
return bridge.loadDBPortal(bridge.DB.Portal.GetByJID(key), &key)
return br.loadDBPortal(br.DB.Portal.GetByJID(key), &key)
}
return portal
}
func (bridge *Bridge) GetAllPortals() []*Portal {
return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAll())
func (br *WABridge) GetAllPortals() []*Portal {
return br.dbPortalsToPortals(br.DB.Portal.GetAll())
}
func (bridge *Bridge) GetAllPortalsForUser(userID id.UserID) []*Portal {
return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAllForUser(userID))
func (br *WABridge) GetAllPortalsForUser(userID id.UserID) []*Portal {
return br.dbPortalsToPortals(br.DB.Portal.GetAllForUser(userID))
}
func (bridge *Bridge) GetAllPortalsByJID(jid types.JID) []*Portal {
return bridge.dbPortalsToPortals(bridge.DB.Portal.GetAllByJID(jid))
func (br *WABridge) GetAllPortalsByJID(jid types.JID) []*Portal {
return br.dbPortalsToPortals(br.DB.Portal.GetAllByJID(jid))
}
func (bridge *Bridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal {
bridge.portalsLock.Lock()
defer bridge.portalsLock.Unlock()
func (br *WABridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal {
br.portalsLock.Lock()
defer br.portalsLock.Unlock()
output := make([]*Portal, len(dbPortals))
for index, dbPortal := range dbPortals {
if dbPortal == nil {
continue
}
portal, ok := bridge.portalsByJID[dbPortal.Key]
portal, ok := br.portalsByJID[dbPortal.Key]
if !ok {
portal = bridge.loadDBPortal(dbPortal, nil)
portal = br.loadDBPortal(dbPortal, nil)
}
output[index] = portal
}
return output
}
func (bridge *Bridge) loadDBPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal {
func (br *WABridge) loadDBPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal {
if dbPortal == nil {
if key == nil {
return nil
}
dbPortal = bridge.DB.Portal.New()
dbPortal = br.DB.Portal.New()
dbPortal.Key = *key
dbPortal.Insert()
}
portal := bridge.NewPortal(dbPortal)
bridge.portalsByJID[portal.Key] = portal
portal := br.NewPortal(dbPortal)
br.portalsByJID[portal.Key] = portal
if len(portal.MXID) > 0 {
bridge.portalsByMXID[portal.MXID] = portal
br.portalsByMXID[portal.MXID] = portal
}
return portal
}
@ -139,14 +144,14 @@ func (portal *Portal) GetUsers() []*User {
return nil
}
func (bridge *Bridge) newBlankPortal(key database.PortalKey) *Portal {
func (br *WABridge) newBlankPortal(key database.PortalKey) *Portal {
portal := &Portal{
bridge: bridge,
log: bridge.Log.Sub(fmt.Sprintf("Portal/%s", key)),
bridge: br,
log: br.Log.Sub(fmt.Sprintf("Portal/%s", key)),
messages: make(chan PortalMessage, bridge.Config.Bridge.PortalMessageBuffer),
matrixMessages: make(chan PortalMatrixMessage, bridge.Config.Bridge.PortalMessageBuffer),
mediaRetries: make(chan PortalMediaRetry, bridge.Config.Bridge.PortalMessageBuffer),
messages: make(chan PortalMessage, br.Config.Bridge.PortalMessageBuffer),
matrixMessages: make(chan PortalMatrixMessage, br.Config.Bridge.PortalMessageBuffer),
mediaRetries: make(chan PortalMediaRetry, br.Config.Bridge.PortalMessageBuffer),
mediaErrorCache: make(map[types.MessageID]*FailedMediaMeta),
}
@ -154,15 +159,15 @@ func (bridge *Bridge) newBlankPortal(key database.PortalKey) *Portal {
return portal
}
func (bridge *Bridge) NewManualPortal(key database.PortalKey) *Portal {
portal := bridge.newBlankPortal(key)
portal.Portal = bridge.DB.Portal.New()
func (br *WABridge) NewManualPortal(key database.PortalKey) *Portal {
portal := br.newBlankPortal(key)
portal.Portal = br.DB.Portal.New()
portal.Key = key
return portal
}
func (bridge *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
portal := bridge.newBlankPortal(dbPortal.Key)
func (br *WABridge) NewPortal(dbPortal *database.Portal) *Portal {
portal := br.newBlankPortal(dbPortal.Key)
portal.Portal = dbPortal
return portal
}
@ -203,7 +208,7 @@ type recentlyHandledWrapper struct {
type Portal struct {
*database.Portal
bridge *Bridge
bridge *WABridge
log log.Logger
roomCreateLock sync.Mutex
@ -229,6 +234,10 @@ type Portal struct {
relayUser *User
}
func (portal *Portal) IsEncrypted() bool {
return portal.Encrypted
}
func (portal *Portal) handleMessageLoopItem(msg PortalMessage) {
if len(portal.MXID) == 0 {
if msg.fake == nil && msg.undecryptable == nil && (msg.evt == nil || !containsSupportedMessage(msg.evt.Message)) {

View File

@ -43,15 +43,15 @@ import (
)
type ProvisioningAPI struct {
bridge *Bridge
bridge *WABridge
log log.Logger
}
func (prov *ProvisioningAPI) Init() {
prov.log = prov.bridge.Log.Sub("Provisioning")
prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.AppService.Provisioning.Prefix)
r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.AppService.Provisioning.Prefix).Subrouter()
prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.Bridge.Provisioning.Prefix)
r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.Bridge.Provisioning.Prefix).Subrouter()
r.Use(prov.AuthMiddleware)
r.HandleFunc("/v1/ping", prov.Ping).Methods(http.MethodGet)
r.HandleFunc("/v1/login", prov.Login).Methods(http.MethodGet)
@ -109,7 +109,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
} else if strings.HasPrefix(auth, "Bearer ") {
auth = auth[len("Bearer "):]
}
if auth != prov.bridge.Config.AppService.Provisioning.SharedSecret {
if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret {
jsonResponse(w, http.StatusForbidden, map[string]interface{}{
"error": "Invalid auth token",
"errcode": "M_FORBIDDEN",

View File

@ -39,11 +39,11 @@ import (
var userIDRegex *regexp.Regexp
func (bridge *Bridge) ParsePuppetMXID(mxid id.UserID) (jid types.JID, ok bool) {
func (br *WABridge) ParsePuppetMXID(mxid id.UserID) (jid types.JID, ok bool) {
if userIDRegex == nil {
userIDRegex = regexp.MustCompile(fmt.Sprintf("^@%s:%s$",
bridge.Config.Bridge.FormatUsername("([0-9]+)"),
bridge.Config.Homeserver.Domain))
br.Config.Bridge.FormatUsername("([0-9]+)"),
br.Config.Homeserver.Domain))
}
match := userIDRegex.FindStringSubmatch(string(mxid))
if len(match) == 2 {
@ -53,79 +53,79 @@ func (bridge *Bridge) ParsePuppetMXID(mxid id.UserID) (jid types.JID, ok bool) {
return
}
func (bridge *Bridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
jid, ok := bridge.ParsePuppetMXID(mxid)
func (br *WABridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
jid, ok := br.ParsePuppetMXID(mxid)
if !ok {
return nil
}
return bridge.GetPuppetByJID(jid)
return br.GetPuppetByJID(jid)
}
func (bridge *Bridge) GetPuppetByJID(jid types.JID) *Puppet {
func (br *WABridge) GetPuppetByJID(jid types.JID) *Puppet {
jid = jid.ToNonAD()
if jid.Server == types.LegacyUserServer {
jid.Server = types.DefaultUserServer
} else if jid.Server != types.DefaultUserServer {
return nil
}
bridge.puppetsLock.Lock()
defer bridge.puppetsLock.Unlock()
puppet, ok := bridge.puppets[jid]
br.puppetsLock.Lock()
defer br.puppetsLock.Unlock()
puppet, ok := br.puppets[jid]
if !ok {
dbPuppet := bridge.DB.Puppet.Get(jid)
dbPuppet := br.DB.Puppet.Get(jid)
if dbPuppet == nil {
dbPuppet = bridge.DB.Puppet.New()
dbPuppet = br.DB.Puppet.New()
dbPuppet.JID = jid
dbPuppet.Insert()
}
puppet = bridge.NewPuppet(dbPuppet)
bridge.puppets[puppet.JID] = puppet
puppet = br.NewPuppet(dbPuppet)
br.puppets[puppet.JID] = puppet
if len(puppet.CustomMXID) > 0 {
bridge.puppetsByCustomMXID[puppet.CustomMXID] = puppet
br.puppetsByCustomMXID[puppet.CustomMXID] = puppet
}
}
return puppet
}
func (bridge *Bridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet {
bridge.puppetsLock.Lock()
defer bridge.puppetsLock.Unlock()
puppet, ok := bridge.puppetsByCustomMXID[mxid]
func (br *WABridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet {
br.puppetsLock.Lock()
defer br.puppetsLock.Unlock()
puppet, ok := br.puppetsByCustomMXID[mxid]
if !ok {
dbPuppet := bridge.DB.Puppet.GetByCustomMXID(mxid)
dbPuppet := br.DB.Puppet.GetByCustomMXID(mxid)
if dbPuppet == nil {
return nil
}
puppet = bridge.NewPuppet(dbPuppet)
bridge.puppets[puppet.JID] = puppet
bridge.puppetsByCustomMXID[puppet.CustomMXID] = puppet
puppet = br.NewPuppet(dbPuppet)
br.puppets[puppet.JID] = puppet
br.puppetsByCustomMXID[puppet.CustomMXID] = puppet
}
return puppet
}
func (bridge *Bridge) GetAllPuppetsWithCustomMXID() []*Puppet {
return bridge.dbPuppetsToPuppets(bridge.DB.Puppet.GetAllWithCustomMXID())
func (br *WABridge) GetAllPuppetsWithCustomMXID() []*Puppet {
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAllWithCustomMXID())
}
func (bridge *Bridge) GetAllPuppets() []*Puppet {
return bridge.dbPuppetsToPuppets(bridge.DB.Puppet.GetAll())
func (br *WABridge) GetAllPuppets() []*Puppet {
return br.dbPuppetsToPuppets(br.DB.Puppet.GetAll())
}
func (bridge *Bridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet {
bridge.puppetsLock.Lock()
defer bridge.puppetsLock.Unlock()
func (br *WABridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet {
br.puppetsLock.Lock()
defer br.puppetsLock.Unlock()
output := make([]*Puppet, len(dbPuppets))
for index, dbPuppet := range dbPuppets {
if dbPuppet == nil {
continue
}
puppet, ok := bridge.puppets[dbPuppet.JID]
puppet, ok := br.puppets[dbPuppet.JID]
if !ok {
puppet = bridge.NewPuppet(dbPuppet)
bridge.puppets[dbPuppet.JID] = puppet
puppet = br.NewPuppet(dbPuppet)
br.puppets[dbPuppet.JID] = puppet
if len(dbPuppet.CustomMXID) > 0 {
bridge.puppetsByCustomMXID[dbPuppet.CustomMXID] = puppet
br.puppetsByCustomMXID[dbPuppet.CustomMXID] = puppet
}
}
output[index] = puppet
@ -133,26 +133,26 @@ func (bridge *Bridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet
return output
}
func (bridge *Bridge) FormatPuppetMXID(jid types.JID) id.UserID {
func (br *WABridge) FormatPuppetMXID(jid types.JID) id.UserID {
return id.NewUserID(
bridge.Config.Bridge.FormatUsername(jid.User),
bridge.Config.Homeserver.Domain)
br.Config.Bridge.FormatUsername(jid.User),
br.Config.Homeserver.Domain)
}
func (bridge *Bridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
func (br *WABridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
return &Puppet{
Puppet: dbPuppet,
bridge: bridge,
log: bridge.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
bridge: br,
log: br.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
MXID: bridge.FormatPuppetMXID(dbPuppet.JID),
MXID: br.FormatPuppetMXID(dbPuppet.JID),
}
}
type Puppet struct {
*database.Puppet
bridge *Bridge
bridge *WABridge
log log.Logger
typingIn id.RoomID

75
user.go
View File

@ -35,6 +35,7 @@ import (
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
@ -56,7 +57,7 @@ type User struct {
Client *whatsmeow.Client
Session *store.Device
bridge *Bridge
bridge *WABridge
log log.Logger
Admin bool
@ -84,38 +85,46 @@ type User struct {
BackfillQueue *BackfillQueue
}
func (bridge *Bridge) getUserByMXID(userID id.UserID, onlyIfExists bool) *User {
_, isPuppet := bridge.ParsePuppetMXID(userID)
if isPuppet || userID == bridge.Bot.UserID {
func (br *WABridge) getUserByMXID(userID id.UserID, onlyIfExists bool) *User {
_, isPuppet := br.ParsePuppetMXID(userID)
if isPuppet || userID == br.Bot.UserID {
return nil
}
bridge.usersLock.Lock()
defer bridge.usersLock.Unlock()
user, ok := bridge.usersByMXID[userID]
br.usersLock.Lock()
defer br.usersLock.Unlock()
user, ok := br.usersByMXID[userID]
if !ok {
userIDPtr := &userID
if onlyIfExists {
userIDPtr = nil
}
return bridge.loadDBUser(bridge.DB.User.GetByMXID(userID), userIDPtr)
return br.loadDBUser(br.DB.User.GetByMXID(userID), userIDPtr)
}
return user
}
func (bridge *Bridge) GetUserByMXID(userID id.UserID) *User {
return bridge.getUserByMXID(userID, false)
func (br *WABridge) GetUserByMXID(userID id.UserID) *User {
return br.getUserByMXID(userID, false)
}
func (bridge *Bridge) GetUserByMXIDIfExists(userID id.UserID) *User {
return bridge.getUserByMXID(userID, true)
func (br *WABridge) GetIUserByMXID(userID id.UserID) bridge.User {
return br.getUserByMXID(userID, false)
}
func (bridge *Bridge) GetUserByJID(jid types.JID) *User {
bridge.usersLock.Lock()
defer bridge.usersLock.Unlock()
user, ok := bridge.usersByUsername[jid.User]
func (user *User) IsAdmin() bool {
return user.Admin
}
func (br *WABridge) GetUserByMXIDIfExists(userID id.UserID) *User {
return br.getUserByMXID(userID, true)
}
func (br *WABridge) GetUserByJID(jid types.JID) *User {
br.usersLock.Lock()
defer br.usersLock.Unlock()
user, ok := br.usersByUsername[jid.User]
if !ok {
return bridge.loadDBUser(bridge.DB.User.GetByUsername(jid.User), nil)
return br.loadDBUser(br.DB.User.GetByUsername(jid.User), nil)
}
return user
}
@ -137,35 +146,35 @@ func (user *User) removeFromJIDMap(state BridgeState) {
user.sendBridgeState(state)
}
func (bridge *Bridge) GetAllUsers() []*User {
bridge.usersLock.Lock()
defer bridge.usersLock.Unlock()
dbUsers := bridge.DB.User.GetAll()
func (br *WABridge) GetAllUsers() []*User {
br.usersLock.Lock()
defer br.usersLock.Unlock()
dbUsers := br.DB.User.GetAll()
output := make([]*User, len(dbUsers))
for index, dbUser := range dbUsers {
user, ok := bridge.usersByMXID[dbUser.MXID]
user, ok := br.usersByMXID[dbUser.MXID]
if !ok {
user = bridge.loadDBUser(dbUser, nil)
user = br.loadDBUser(dbUser, nil)
}
output[index] = user
}
return output
}
func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
func (br *WABridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
if dbUser == nil {
if mxid == nil {
return nil
}
dbUser = bridge.DB.User.New()
dbUser = br.DB.User.New()
dbUser.MXID = *mxid
dbUser.Insert()
}
user := bridge.NewUser(dbUser)
bridge.usersByMXID[user.MXID] = user
user := br.NewUser(dbUser)
br.usersByMXID[user.MXID] = user
if !user.JID.IsEmpty() {
var err error
user.Session, err = bridge.WAContainer.GetDevice(user.JID)
user.Session, err = br.WAContainer.GetDevice(user.JID)
if err != nil {
user.log.Errorfln("Failed to load user's whatsapp session: %v", err)
} else if user.Session == nil {
@ -174,20 +183,20 @@ func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
user.Update()
} else {
user.Session.Log = &waLogger{user.log.Sub("Session")}
bridge.usersByUsername[user.JID.User] = user
br.usersByUsername[user.JID.User] = user
}
}
if len(user.ManagementRoom) > 0 {
bridge.managementRooms[user.ManagementRoom] = user
br.managementRooms[user.ManagementRoom] = user
}
return user
}
func (bridge *Bridge) NewUser(dbUser *database.User) *User {
func (br *WABridge) NewUser(dbUser *database.User) *User {
user := &User{
User: dbUser,
bridge: bridge,
log: bridge.Log.Sub("User").Sub(string(dbUser.MXID)),
bridge: br,
log: br.Log.Sub("User").Sub(string(dbUser.MXID)),
historySyncs: make(chan *events.HistorySync, 32),
lastPresence: types.PresenceUnavailable,