From dfc5722a8036954482d3ac0afa7c51e62d63995e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 9 May 2020 20:07:21 +0300 Subject: [PATCH] Move crypto store to main database --- crypto.go | 126 +++--- database/cryptostore.go | 393 +++++++++++++++++++ database/statestore.go | 28 +- database/upgrades/2020-05-09-crypto-store.go | 74 ++++ database/upgrades/upgrades.go | 2 +- go.mod | 2 +- go.sum | 2 + main.go | 14 +- 8 files changed, 561 insertions(+), 80 deletions(-) create mode 100644 database/cryptostore.go create mode 100644 database/upgrades/2020-05-09-crypto-store.go diff --git a/crypto.go b/crypto.go index 17094c6..286b600 100644 --- a/crypto.go +++ b/crypto.go @@ -28,6 +28,7 @@ import ( "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix" + "maunium.net/go/mautrix-whatsapp/database" "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -40,13 +41,15 @@ var levelTrace = maulogger.Level{ } type CryptoHelper struct { - bridge *Bridge - client *mautrix.Client - mach *crypto.OlmMachine - log maulogger.Logger + bridge *Bridge + client *mautrix.Client + mach *crypto.OlmMachine + store *database.SQLCryptoStore + log maulogger.Logger + baseLog maulogger.Logger } -func (bridge *Bridge) initCrypto() error { +func NewCryptoHelper(bridge *Bridge) *CryptoHelper { if !bridge.Config.Bridge.Encryption.Allow { bridge.Log.Debugln("Bridge built with end-to-bridge encryption, but disabled in config") return nil @@ -54,39 +57,60 @@ func (bridge *Bridge) initCrypto() error { bridge.Log.Warnln("End-to-bridge encryption enabled, but login_shared_secret not set") return nil } - bridge.Log.Debugln("Initializing end-to-bridge encryption...") - client, err := bridge.loginBot() - if err != nil { - return err + baseLog := bridge.Log.Sub("Crypto") + return &CryptoHelper{ + bridge: bridge, + log: baseLog.Sub("Helper"), + baseLog: baseLog, } - // TODO put this in the database - cryptoStore, err := crypto.NewGobStore("crypto.gob") +} + +func (helper *CryptoHelper) Init() error { + helper.log.Debugln("Initializing end-to-bridge encryption...") + var err error + helper.client, err = helper.loginBot() if err != nil { return err } - log := bridge.Log.Sub("Crypto") - logger := &cryptoLogger{log} - stateStore := &cryptoStateStore{bridge} - helper := &CryptoHelper{ - bridge: bridge, - client: client, - log: log.Sub("Helper"), - mach: crypto.NewOlmMachine(client, logger, cryptoStore, stateStore), + helper.log.Debugln("Logged in as bridge bot with device ID", helper.client.DeviceID) + logger := &cryptoLogger{helper.baseLog} + stateStore := &cryptoStateStore{helper.bridge} + helper.store = database.NewSQLCryptoStore(helper.bridge.DB, helper.client.DeviceID) + helper.store.UserID = helper.client.UserID + helper.store.GhostIDFormat = helper.bridge.Config.Bridge.FormatUsername("%") + helper.mach = crypto.NewOlmMachine(helper.client, logger, helper.store, stateStore) + + helper.client.Logger = logger.int.Sub("Bot") + helper.client.Syncer = &cryptoSyncer{helper.mach} + helper.client.Store = &cryptoClientStore{helper.store} + + return helper.mach.Load() +} + +func (helper *CryptoHelper) loginBot() (*mautrix.Client, error) { + deviceID := helper.bridge.DB.FindDeviceID() + if len(deviceID) > 0 { + helper.log.Debugln("Found existing device ID for bot in database:", deviceID) } - - client.Logger = logger.int.Sub("Bot") - client.Syncer = &cryptoSyncer{helper.mach} - // TODO put this in the database too - client.Store = mautrix.NewInMemoryStore() - - err = helper.mach.Load() + mac := hmac.New(sha512.New, []byte(helper.bridge.Config.Bridge.LoginSharedSecret)) + mac.Write([]byte(helper.bridge.AS.BotMXID())) + resp, err := helper.bridge.AS.BotClient().Login(&mautrix.ReqLogin{ + Type: "m.login.password", + Identifier: mautrix.UserIdentifier{Type: "m.id.user", User: string(helper.bridge.AS.BotMXID())}, + Password: hex.EncodeToString(mac.Sum(nil)), + DeviceID: deviceID, + InitialDeviceDisplayName: "WhatsApp Bridge", + }) if err != nil { - return err + return nil, err } - - bridge.Crypto = helper - return nil + client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, helper.bridge.AS.BotMXID(), resp.AccessToken) + if err != nil { + return nil, err + } + client.DeviceID = resp.DeviceID + return client, nil } func (helper *CryptoHelper) Start() { @@ -101,27 +125,6 @@ func (helper *CryptoHelper) Stop() { helper.client.StopSync() } -func (bridge *Bridge) loginBot() (*mautrix.Client, error) { - mac := hmac.New(sha512.New, []byte(bridge.Config.Bridge.LoginSharedSecret)) - mac.Write([]byte(bridge.AS.BotMXID())) - resp, err := bridge.AS.BotClient().Login(&mautrix.ReqLogin{ - Type: "m.login.password", - Identifier: mautrix.UserIdentifier{Type: "m.id.user", User: string(bridge.AS.BotMXID())}, - Password: hex.EncodeToString(mac.Sum(nil)), - DeviceID: "WhatsApp Bridge", - InitialDeviceDisplayName: "WhatsApp Bridge", - }) - if err != nil { - return nil, err - } - client, err := mautrix.NewClient(bridge.AS.HomeserverURL, bridge.AS.BotMXID(), resp.AccessToken) - if err != nil { - return nil, err - } - client.DeviceID = "WhatsApp Bridge" - return client, nil -} - func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) { return helper.mach.DecryptMegolmEvent(evt) } @@ -133,7 +136,7 @@ func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, conten return nil, err } helper.log.Debugfln("Got %v while encrypting event for %s, sharing group session and trying again...", err, roomID) - users, err := helper.bridge.StateStore.GetRoomMemberList(roomID) + users, err := helper.store.GetRoomMembers(roomID) if err != nil { return nil, errors.Wrap(err, "failed to get room member list") } @@ -202,6 +205,25 @@ func (c *cryptoLogger) Trace(message string, args ...interface{}) { c.int.Logfln(levelTrace, message, args...) } +type cryptoClientStore struct { + int *database.SQLCryptoStore +} + +func (c cryptoClientStore) SaveFilterID(_ id.UserID, _ string) {} +func (c cryptoClientStore) LoadFilterID(_ id.UserID) string { return "" } +func (c cryptoClientStore) SaveRoom(_ *mautrix.Room) {} +func (c cryptoClientStore) LoadRoom(_ id.RoomID) *mautrix.Room { return nil } + +func (c cryptoClientStore) SaveNextBatch(_ id.UserID, nextBatchToken string) { + c.int.PutNextBatch(nextBatchToken) +} + +func (c cryptoClientStore) LoadNextBatch(_ id.UserID) string { + return c.int.GetNextBatch() +} + +var _ mautrix.Storer = (*cryptoClientStore)(nil) + type cryptoStateStore struct { bridge *Bridge } diff --git a/database/cryptostore.go b/database/cryptostore.go new file mode 100644 index 0000000..8b36216 --- /dev/null +++ b/database/cryptostore.go @@ -0,0 +1,393 @@ +// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. +// Copyright (C) 2020 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package database + +import ( + "database/sql" + "fmt" + "strings" + "sync" + + "github.com/lib/pq" + "github.com/pkg/errors" + log "maunium.net/go/maulogger/v2" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/id" +) + +type SQLCryptoStore struct { + db *Database + log log.Logger + + UserID id.UserID + DeviceID id.DeviceID + SyncToken string + PickleKey []byte + Account *crypto.OlmAccount + + GhostIDFormat string + + OGSLock sync.RWMutex + OutGroupSessions map[id.RoomID]*crypto.OutboundGroupSession +} + +var _ crypto.Store = (*SQLCryptoStore)(nil) + +func NewSQLCryptoStore(db *Database, deviceID id.DeviceID) *SQLCryptoStore { + return &SQLCryptoStore{ + db: db, + log: db.log.Sub("CryptoStore"), + PickleKey: []byte("maunium.net/go/mautrix-whatsapp"), + DeviceID: deviceID, + + OutGroupSessions: make(map[id.RoomID]*crypto.OutboundGroupSession), + } +} + +func (db *Database) FindDeviceID() (deviceID id.DeviceID) { + err := db.QueryRow("SELECT device_id FROM crypto_account LIMIT 1").Scan(&deviceID) + if err != nil && err != sql.ErrNoRows { + db.log.Warnln("Failed to scan device ID:", err) + } + return +} + +func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.UserID, err error) { + var rows *sql.Rows + rows, err = store.db.Query(` + SELECT user_id FROM mx_user_profile + WHERE room_id=$1 + AND (membership='join' OR membership='invite') + AND user_id<>$2 + AND user_id NOT LIKE $3 + `, roomID, store.UserID, store.GhostIDFormat) + if err != nil { + return + } + for rows.Next() { + var userID id.UserID + err := rows.Scan(&userID) + if err != nil { + store.log.Warnfln("Failed to scan member in %s: %v", roomID, err) + } else { + members = append(members, userID) + } + } + return +} + +func (store *SQLCryptoStore) Flush() error { + return nil +} + +func (store *SQLCryptoStore) PutNextBatch(nextBatch string) { + store.SyncToken = nextBatch + _, err := store.db.Exec(`UPDATE crypto_account SET sync_token=$1 WHERE device_id=$2`, store.SyncToken, store.DeviceID) + if err != nil { + store.log.Warnln("Failed to store sync token:", err) + } +} + +func (store *SQLCryptoStore) GetNextBatch() string { + if store.SyncToken == "" { + err := store.db. + QueryRow("SELECT sync_token FROM crypto_account WHERE device_id=$1", store.DeviceID). + Scan(&store.SyncToken) + if err != nil && err != sql.ErrNoRows { + store.log.Warnln("Failed to scan sync token:", err) + } + } + return store.SyncToken +} + +func (store *SQLCryptoStore) PutAccount(account *crypto.OlmAccount) error { + store.Account = account + bytes := account.Internal.Pickle(store.PickleKey) + var err error + if store.db.dialect == "postgres" { + _, err = store.db.Exec(` + INSERT INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4) + ON CONFLICT (device_id) DO UPDATE SET shared=$2, sync_token=$3, account=$4`, + store.DeviceID, account.Shared, store.SyncToken, bytes) + } else if store.db.dialect == "sqlite3" { + _, err = store.db.Exec("INSERT OR REPLACE INTO crypto_account (deivce_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)", + store.DeviceID, account.Shared, store.SyncToken, bytes) + } else { + err = fmt.Errorf("unsupported dialect %s", store.db.dialect) + } + if err != nil { + store.log.Warnln("Failed to store account:", err) + } + return nil +} + +func (store *SQLCryptoStore) GetAccount() (*crypto.OlmAccount, error) { + if store.Account == nil { + row := store.db.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE device_id=$1", store.DeviceID) + acc := &crypto.OlmAccount{Internal: *olm.NewBlankAccount()} + var accountBytes []byte + err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes) + if err == sql.ErrNoRows { + return nil, nil + } else if err != nil { + return nil, err + } + err = acc.Internal.Unpickle(accountBytes, store.PickleKey) + if err != nil { + return nil, err + } + store.Account = acc + } + return store.Account, nil +} + +func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool { + // TODO this may need to be changed if olm sessions start expiring + var sessionID id.SessionID + err := store.db.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 LIMIT 1", key).Scan(&sessionID) + if err == sql.ErrNoRows { + return false + } + return len(sessionID) > 0 +} + +func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (crypto.OlmSessionList, error) { + rows, err := store.db.Query("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id", key) + if err != nil { + return nil, err + } + list := crypto.OlmSessionList{} + for rows.Next() { + sess := crypto.OlmSession{Internal: *olm.NewBlankSession()} + var sessionBytes []byte + err := rows.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime) + if err != nil { + return nil, err + } + err = sess.Internal.Unpickle(sessionBytes, store.PickleKey) + if err != nil { + return nil, err + } + list = append(list, &sess) + } + return list, nil +} + +func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*crypto.OlmSession, error) { + row := store.db.QueryRow("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id DESC LIMIT 1", key) + sess := crypto.OlmSession{Internal: *olm.NewBlankSession()} + var sessionBytes []byte + err := row.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime) + if err == sql.ErrNoRows { + return nil, nil + } else if err != nil { + return nil, err + } + return &sess, sess.Internal.Unpickle(sessionBytes, store.PickleKey) +} + +func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *crypto.OlmSession) error { + sessionBytes := session.Internal.Pickle(store.PickleKey) + _, err := store.db.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_used) VALUES ($1, $2, $3, $4, $5)", + session.ID(), key, sessionBytes, session.CreationTime, session.UseTime) + return err +} + +func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *crypto.InboundGroupSession) error { + sessionBytes := session.Internal.Pickle(store.PickleKey) + forwardingChains := strings.Join(session.ForwardingChains, ",") + _, err := store.db.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, signing_key, room_id, session, forwarding_chains) VALUES ($1, $2, $3, $4, $5, $6)", + sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains) + return err +} + +func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*crypto.InboundGroupSession, error) { + var signingKey id.Ed25519 + var sessionBytes []byte + var forwardingChains string + err := store.db.QueryRow(` + SELECT signing_key, session, forwarding_chains + FROM crypto_megolm_inbound_session + WHERE room_id=$1 AND sender_key=$2 AND session_id=$3`, + roomID, senderKey, sessionID, + ).Scan(&signingKey, &sessionBytes, &forwardingChains) + if err == sql.ErrNoRows { + return nil, nil + } else if err != nil { + return nil, err + } + igs := olm.NewBlankInboundGroupSession() + err = igs.Unpickle(sessionBytes, store.PickleKey) + if err != nil { + return nil, err + } + return &crypto.InboundGroupSession{ + Internal: *igs, + SigningKey: signingKey, + SenderKey: senderKey, + RoomID: roomID, + ForwardingChains: strings.Split(forwardingChains, ","), + }, nil +} + +func (store *SQLCryptoStore) PutOutboundGroupSession(roomID id.RoomID, session *crypto.OutboundGroupSession) error { + store.OGSLock.Lock() + store.OutGroupSessions[roomID] = session + store.OGSLock.Unlock() + return nil +} + +func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) { + store.OGSLock.RLock() + defer store.OGSLock.RUnlock() + return store.OutGroupSessions[roomID], nil +} + +func (store *SQLCryptoStore) PopOutboundGroupSession(roomID id.RoomID) error { + store.OGSLock.Lock() + delete(store.OutGroupSessions, roomID) + store.OGSLock.Unlock() + return nil +} + +func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool { + var resultEventID id.EventID + var resultTimestamp int64 + err := store.db.QueryRow( + "SELECT event_id, timestamp FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND index=$3", + senderKey, sessionID, index, + ).Scan(&resultEventID, &resultTimestamp) + if err == sql.ErrNoRows { + _, err := store.db.Exec("INSERT INTO crypto_message_index (sender_key, session_id, index, event_id, timestamp) VALUES ($1, $2, $3, $4, $5)", + senderKey, sessionID, index, eventID, timestamp) + if err != nil { + store.log.Warnln("Failed to store message index:", err) + } + return true + } else if err != nil { + store.log.Warnln("Failed to scan message index:", err) + return true + } + if resultEventID != eventID || resultTimestamp != timestamp { + return false + } + return true +} + +func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*crypto.DeviceIdentity, error) { + var ignore id.UserID + err := store.db.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore) + if err == sql.ErrNoRows { + return nil, nil + } else if err != nil { + return nil, err + } + + rows, err := store.db.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1", userID) + if err != nil { + return nil, err + } + data := make(map[id.DeviceID]*crypto.DeviceIdentity) + for rows.Next() { + var identity crypto.DeviceIdentity + err := rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name) + if err != nil { + return nil, err + } + identity.UserID = userID + data[identity.DeviceID] = &identity + } + return data, nil +} + +func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*crypto.DeviceIdentity) error { + tx, err := store.db.Begin() + if err != nil { + return err + } + + if store.db.dialect == "postgres" { + _, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) + } else if store.db.dialect == "sqlite3" { + _, err = tx.Exec("INSERT OR IGNORE INTO crypto_tracked_users (user_id) VALUES ($1)", userID) + } else { + err = fmt.Errorf("unsupported dialect %s", store.db.dialect) + } + if err != nil { + return errors.Wrap(err, "failed to add user to tracked users list") + } + + _, err = tx.Exec("DELETE FROM crypto_device WHERE user_id=$1", userID) + if err != nil { + _ = tx.Rollback() + return errors.Wrap(err, "failed to delete old devices") + } + if len(devices) == 0 { + err = tx.Commit() + if err != nil { + return errors.Wrap(err, "failed to commit changes (no devices added)") + } + return nil + } + // TODO do this in batches to avoid too large db queries + values := make([]interface{}, 1, len(devices)*6+1) + values[0] = userID + valueStrings := make([]string, 0, len(devices)) + i := 2 + for deviceID, identity := range devices { + values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name) + valueStrings = append(valueStrings, fmt.Sprintf("($1, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5)) + i += 6 + } + valueString := strings.Join(valueStrings, ",") + _, err = tx.Exec("INSERT INTO crypto_device (user_id, device_id, identity_key, signing_key, trust, deleted, name) VALUES "+valueString, values...) + if err != nil { + _ = tx.Rollback() + return errors.Wrap(err, "failed to insert new devices") + } + err = tx.Commit() + if err != nil { + return errors.Wrap(err, "failed to commit changes") + } + return nil +} + +func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID { + var rows *sql.Rows + var err error + if store.db.dialect == "postgres" { + rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", pq.Array(users)) + } else { + rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ($1)", users) + } + if err != nil { + store.log.Warnln("Failed to filter tracked users:", err) + return users + } + var ptr int + for rows.Next() { + err = rows.Scan(&users[ptr]) + if err != nil { + store.log.Warnln("Failed to tracked user ID:", err) + } else { + ptr++ + } + } + return users[:ptr] +} diff --git a/database/statestore.go b/database/statestore.go index 178d4c1..e76d60f 100644 --- a/database/statestore.go +++ b/database/statestore.go @@ -39,11 +39,13 @@ type SQLStateStore struct { typingLock sync.RWMutex } +var _ appservice.StateStore = (*SQLStateStore)(nil) + func NewSQLStateStore(db *Database) *SQLStateStore { return &SQLStateStore{ TypingStateStore: appservice.NewTypingStateStore(), db: db, - log: log.Sub("StateStore"), + log: db.log.Sub("StateStore"), } } @@ -90,24 +92,6 @@ func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*even return members } -func (store *SQLStateStore) GetRoomMemberList(roomID id.RoomID) (members []id.UserID, err error) { - var rows *sql.Rows - rows, err = store.db.Query("SELECT user_id FROM mx_user_profile WHERE room_id=$1", roomID) - if err != nil { - return - } - for rows.Next() { - var userID id.UserID - err := rows.Scan(&userID) - if err != nil { - store.log.Warnfln("Failed to scan member in %s: %v", roomID, err) - } else { - members = append(members, userID) - } - } - return -} - func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership { row := store.db.QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID) membership := event.MembershipLeave @@ -138,8 +122,10 @@ func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*e func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) { rows, err := store.db.Query(` - SELECT room_id FROM mx_user_profile WHERE user_id=$2 AND portal.encrypted=true - LEFT JOIN portal WHEN portal.mxid=mx_user_profile.room_id`, userID) + SELECT room_id FROM mx_user_profile + LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id + WHERE user_id=$1 AND portal.encrypted=true + `, userID) if err != nil { store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err) return diff --git a/database/upgrades/2020-05-09-crypto-store.go b/database/upgrades/2020-05-09-crypto-store.go new file mode 100644 index 0000000..529ff9c --- /dev/null +++ b/database/upgrades/2020-05-09-crypto-store.go @@ -0,0 +1,74 @@ +package upgrades + +import ( + "database/sql" +) + +func init() { + upgrades[13] = upgrade{"Add crypto store to database", func(tx *sql.Tx, ctx context) error { + // TODO use DATETIME instead of timestamp and BLOB instead of bytea for sqlite + _, err := tx.Exec(`CREATE TABLE crypto_account ( + device_id VARCHAR(255) PRIMARY KEY, + shared BOOLEAN NOT NULL, + sync_token TEXT NOT NULL, + account bytea NOT NULL + )`) + if err != nil { + return err + } + _, err = tx.Exec(`CREATE TABLE crypto_message_index ( + sender_key CHAR(43), + session_id VARCHAR(255), + index INTEGER, + event_id VARCHAR(255) NOT NULL, + timestamp BIGINT NOT NULL, + + PRIMARY KEY (sender_key, session_id, index) + )`) + if err != nil { + return err + } + _, err = tx.Exec(`CREATE TABLE crypto_tracked_user ( + user_id VARCHAR(255) PRIMARY KEY + )`) + if err != nil { + return err + } + _, err = tx.Exec(`CREATE TABLE crypto_device ( + user_id VARCHAR(255), + device_id VARCHAR(255), + identity_key CHAR(43) NOT NULL, + signing_key CHAR(43) NOT NULL, + trust SMALLINT NOT NULL, + deleted BOOLEAN NOT NULL, + name VARCHAR(255) NOT NULL, + + PRIMARY KEY (user_id, device_id) + )`) + if err != nil { + return err + } + _, err = tx.Exec(`CREATE TABLE crypto_olm_session ( + session_id CHAR(43) PRIMARY KEY, + sender_key VARCHAR(255) NOT NULL, + session bytea NOT NULL, + created_at timestamp NOT NULL, + last_used timestamp NOT NULL + )`) + if err != nil { + return err + } + _, err = tx.Exec(`CREATE TABLE crypto_megolm_inbound_session ( + session_id CHAR(43) PRIMARY KEY, + sender_key CHAR(43) NOT NULL, + signing_key CHAR(43) NOT NULL, + room_id VARCHAR(255) NOT NULL, + session bytea NOT NULL, + forwarding_chains bytea NOT NULL + )`) + if err != nil { + return err + } + return nil + }} +} diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go index 3126cc7..ec8e6e7 100644 --- a/database/upgrades/upgrades.go +++ b/database/upgrades/upgrades.go @@ -28,7 +28,7 @@ type upgrade struct { fn upgradeFunc } -const NumberOfUpgrades = 13 +const NumberOfUpgrades = 14 var upgrades [NumberOfUpgrades]upgrade diff --git a/go.mod b/go.mod index 4534d58..38dc65e 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( gopkg.in/yaml.v2 v2.2.8 maunium.net/go/mauflag v1.0.0 maunium.net/go/maulogger/v2 v2.1.1 - maunium.net/go/mautrix v0.4.0 + maunium.net/go/mautrix v0.4.1 ) replace github.com/Rhymen/go-whatsapp => github.com/tulir/go-whatsapp v0.2.6 diff --git a/go.sum b/go.sum index cba311b..a109e80 100644 --- a/go.sum +++ b/go.sum @@ -86,3 +86,5 @@ maunium.net/go/mautrix v0.3.7 h1:N0czrZeAwjvBrw2a/B2G6U3EwIYaWpt7OuSslGp8DRc= maunium.net/go/mautrix v0.3.7/go.mod h1:SkGZzch8CvU2qKtNpYxtzZ0sQxfVEJ3IsVVLSUBUx9Y= maunium.net/go/mautrix v0.4.0 h1:IYfmxCoxR/6UMi92IncsSZeKQbZm8Xa35XIRX814KJ4= maunium.net/go/mautrix v0.4.0/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho= +maunium.net/go/mautrix v0.4.1 h1:i2lJNT+TE4AAL3cVKUN4jKVRkujCE/oS8aIsj8+7iNE= +maunium.net/go/mautrix v0.4.1/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho= diff --git a/main.go b/main.go index ddcd4b3..13cb13e 100644 --- a/main.go +++ b/main.go @@ -126,6 +126,7 @@ type Crypto interface { HandleMemberEvent(*event.Event) Decrypt(*event.Event) (*event.Event, error) Encrypt(id.RoomID, event.Type, event.Content) (*event.EncryptedEventContent, error) + Init() error Start() Stop() } @@ -225,11 +226,7 @@ func (bridge *Bridge) Init() { bridge.Log.Debugln("Initializing Matrix event handler") bridge.MatrixHandler = NewMatrixHandler(bridge) bridge.Formatter = NewFormatter(bridge) - err = bridge.initCrypto() - if err != nil { - bridge.Log.Fatalln("Error initializing end-to-bridge encryption:", err) - os.Exit(19) - } + bridge.Crypto = NewCryptoHelper(bridge) } func (bridge *Bridge) Start() { @@ -238,6 +235,13 @@ func (bridge *Bridge) Start() { bridge.Log.Fatalln("Failed to initialize database:", err) os.Exit(15) } + if bridge.Crypto != nil { + err := bridge.Crypto.Init() + if err != nil { + bridge.Log.Fatalln("Error initializing end-to-bridge encryption:", err) + os.Exit(19) + } + } if bridge.Provisioning != nil { bridge.Log.Debugln("Initializing provisioning API") bridge.Provisioning.Init()