mirror of
https://github.com/tulir/mautrix-whatsapp
synced 2024-11-12 04:52:40 +01:00
Move crypto store to main database
This commit is contained in:
parent
6e50a7c380
commit
dfc5722a80
8 changed files with 561 additions and 80 deletions
120
crypto.go
120
crypto.go
|
@ -28,6 +28,7 @@ import (
|
||||||
"maunium.net/go/maulogger/v2"
|
"maunium.net/go/maulogger/v2"
|
||||||
|
|
||||||
"maunium.net/go/mautrix"
|
"maunium.net/go/mautrix"
|
||||||
|
"maunium.net/go/mautrix-whatsapp/database"
|
||||||
"maunium.net/go/mautrix/crypto"
|
"maunium.net/go/mautrix/crypto"
|
||||||
"maunium.net/go/mautrix/event"
|
"maunium.net/go/mautrix/event"
|
||||||
"maunium.net/go/mautrix/id"
|
"maunium.net/go/mautrix/id"
|
||||||
|
@ -43,10 +44,12 @@ type CryptoHelper struct {
|
||||||
bridge *Bridge
|
bridge *Bridge
|
||||||
client *mautrix.Client
|
client *mautrix.Client
|
||||||
mach *crypto.OlmMachine
|
mach *crypto.OlmMachine
|
||||||
|
store *database.SQLCryptoStore
|
||||||
log maulogger.Logger
|
log maulogger.Logger
|
||||||
|
baseLog maulogger.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) initCrypto() error {
|
func NewCryptoHelper(bridge *Bridge) *CryptoHelper {
|
||||||
if !bridge.Config.Bridge.Encryption.Allow {
|
if !bridge.Config.Bridge.Encryption.Allow {
|
||||||
bridge.Log.Debugln("Bridge built with end-to-bridge encryption, but disabled in config")
|
bridge.Log.Debugln("Bridge built with end-to-bridge encryption, but disabled in config")
|
||||||
return nil
|
return nil
|
||||||
|
@ -54,39 +57,60 @@ func (bridge *Bridge) initCrypto() error {
|
||||||
bridge.Log.Warnln("End-to-bridge encryption enabled, but login_shared_secret not set")
|
bridge.Log.Warnln("End-to-bridge encryption enabled, but login_shared_secret not set")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
bridge.Log.Debugln("Initializing end-to-bridge encryption...")
|
baseLog := bridge.Log.Sub("Crypto")
|
||||||
client, err := bridge.loginBot()
|
return &CryptoHelper{
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// TODO put this in the database
|
|
||||||
cryptoStore, err := crypto.NewGobStore("crypto.gob")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log := bridge.Log.Sub("Crypto")
|
|
||||||
logger := &cryptoLogger{log}
|
|
||||||
stateStore := &cryptoStateStore{bridge}
|
|
||||||
helper := &CryptoHelper{
|
|
||||||
bridge: bridge,
|
bridge: bridge,
|
||||||
client: client,
|
log: baseLog.Sub("Helper"),
|
||||||
log: log.Sub("Helper"),
|
baseLog: baseLog,
|
||||||
mach: crypto.NewOlmMachine(client, logger, cryptoStore, stateStore),
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
client.Logger = logger.int.Sub("Bot")
|
func (helper *CryptoHelper) Init() error {
|
||||||
client.Syncer = &cryptoSyncer{helper.mach}
|
helper.log.Debugln("Initializing end-to-bridge encryption...")
|
||||||
// TODO put this in the database too
|
var err error
|
||||||
client.Store = mautrix.NewInMemoryStore()
|
helper.client, err = helper.loginBot()
|
||||||
|
|
||||||
err = helper.mach.Load()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
bridge.Crypto = helper
|
helper.log.Debugln("Logged in as bridge bot with device ID", helper.client.DeviceID)
|
||||||
return nil
|
logger := &cryptoLogger{helper.baseLog}
|
||||||
|
stateStore := &cryptoStateStore{helper.bridge}
|
||||||
|
helper.store = database.NewSQLCryptoStore(helper.bridge.DB, helper.client.DeviceID)
|
||||||
|
helper.store.UserID = helper.client.UserID
|
||||||
|
helper.store.GhostIDFormat = helper.bridge.Config.Bridge.FormatUsername("%")
|
||||||
|
helper.mach = crypto.NewOlmMachine(helper.client, logger, helper.store, stateStore)
|
||||||
|
|
||||||
|
helper.client.Logger = logger.int.Sub("Bot")
|
||||||
|
helper.client.Syncer = &cryptoSyncer{helper.mach}
|
||||||
|
helper.client.Store = &cryptoClientStore{helper.store}
|
||||||
|
|
||||||
|
return helper.mach.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (helper *CryptoHelper) loginBot() (*mautrix.Client, error) {
|
||||||
|
deviceID := helper.bridge.DB.FindDeviceID()
|
||||||
|
if len(deviceID) > 0 {
|
||||||
|
helper.log.Debugln("Found existing device ID for bot in database:", deviceID)
|
||||||
|
}
|
||||||
|
mac := hmac.New(sha512.New, []byte(helper.bridge.Config.Bridge.LoginSharedSecret))
|
||||||
|
mac.Write([]byte(helper.bridge.AS.BotMXID()))
|
||||||
|
resp, err := helper.bridge.AS.BotClient().Login(&mautrix.ReqLogin{
|
||||||
|
Type: "m.login.password",
|
||||||
|
Identifier: mautrix.UserIdentifier{Type: "m.id.user", User: string(helper.bridge.AS.BotMXID())},
|
||||||
|
Password: hex.EncodeToString(mac.Sum(nil)),
|
||||||
|
DeviceID: deviceID,
|
||||||
|
InitialDeviceDisplayName: "WhatsApp Bridge",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, helper.bridge.AS.BotMXID(), resp.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
client.DeviceID = resp.DeviceID
|
||||||
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (helper *CryptoHelper) Start() {
|
func (helper *CryptoHelper) Start() {
|
||||||
|
@ -101,27 +125,6 @@ func (helper *CryptoHelper) Stop() {
|
||||||
helper.client.StopSync()
|
helper.client.StopSync()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) loginBot() (*mautrix.Client, error) {
|
|
||||||
mac := hmac.New(sha512.New, []byte(bridge.Config.Bridge.LoginSharedSecret))
|
|
||||||
mac.Write([]byte(bridge.AS.BotMXID()))
|
|
||||||
resp, err := bridge.AS.BotClient().Login(&mautrix.ReqLogin{
|
|
||||||
Type: "m.login.password",
|
|
||||||
Identifier: mautrix.UserIdentifier{Type: "m.id.user", User: string(bridge.AS.BotMXID())},
|
|
||||||
Password: hex.EncodeToString(mac.Sum(nil)),
|
|
||||||
DeviceID: "WhatsApp Bridge",
|
|
||||||
InitialDeviceDisplayName: "WhatsApp Bridge",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
client, err := mautrix.NewClient(bridge.AS.HomeserverURL, bridge.AS.BotMXID(), resp.AccessToken)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
client.DeviceID = "WhatsApp Bridge"
|
|
||||||
return client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
|
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
|
||||||
return helper.mach.DecryptMegolmEvent(evt)
|
return helper.mach.DecryptMegolmEvent(evt)
|
||||||
}
|
}
|
||||||
|
@ -133,7 +136,7 @@ func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, conten
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
helper.log.Debugfln("Got %v while encrypting event for %s, sharing group session and trying again...", err, roomID)
|
helper.log.Debugfln("Got %v while encrypting event for %s, sharing group session and trying again...", err, roomID)
|
||||||
users, err := helper.bridge.StateStore.GetRoomMemberList(roomID)
|
users, err := helper.store.GetRoomMembers(roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to get room member list")
|
return nil, errors.Wrap(err, "failed to get room member list")
|
||||||
}
|
}
|
||||||
|
@ -202,6 +205,25 @@ func (c *cryptoLogger) Trace(message string, args ...interface{}) {
|
||||||
c.int.Logfln(levelTrace, message, args...)
|
c.int.Logfln(levelTrace, message, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type cryptoClientStore struct {
|
||||||
|
int *database.SQLCryptoStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c cryptoClientStore) SaveFilterID(_ id.UserID, _ string) {}
|
||||||
|
func (c cryptoClientStore) LoadFilterID(_ id.UserID) string { return "" }
|
||||||
|
func (c cryptoClientStore) SaveRoom(_ *mautrix.Room) {}
|
||||||
|
func (c cryptoClientStore) LoadRoom(_ id.RoomID) *mautrix.Room { return nil }
|
||||||
|
|
||||||
|
func (c cryptoClientStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
|
||||||
|
c.int.PutNextBatch(nextBatchToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c cryptoClientStore) LoadNextBatch(_ id.UserID) string {
|
||||||
|
return c.int.GetNextBatch()
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ mautrix.Storer = (*cryptoClientStore)(nil)
|
||||||
|
|
||||||
type cryptoStateStore struct {
|
type cryptoStateStore struct {
|
||||||
bridge *Bridge
|
bridge *Bridge
|
||||||
}
|
}
|
||||||
|
|
393
database/cryptostore.go
Normal file
393
database/cryptostore.go
Normal file
|
@ -0,0 +1,393 @@
|
||||||
|
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
|
||||||
|
// Copyright (C) 2020 Tulir Asokan
|
||||||
|
//
|
||||||
|
// This program is free software: you can redistribute it and/or modify
|
||||||
|
// it under the terms of the GNU Affero General Public License as published by
|
||||||
|
// the Free Software Foundation, either version 3 of the License, or
|
||||||
|
// (at your option) any later version.
|
||||||
|
//
|
||||||
|
// This program is distributed in the hope that it will be useful,
|
||||||
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
// GNU Affero General Public License for more details.
|
||||||
|
//
|
||||||
|
// You should have received a copy of the GNU Affero General Public License
|
||||||
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
log "maunium.net/go/maulogger/v2"
|
||||||
|
|
||||||
|
"maunium.net/go/mautrix/crypto"
|
||||||
|
"maunium.net/go/mautrix/crypto/olm"
|
||||||
|
"maunium.net/go/mautrix/id"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SQLCryptoStore struct {
|
||||||
|
db *Database
|
||||||
|
log log.Logger
|
||||||
|
|
||||||
|
UserID id.UserID
|
||||||
|
DeviceID id.DeviceID
|
||||||
|
SyncToken string
|
||||||
|
PickleKey []byte
|
||||||
|
Account *crypto.OlmAccount
|
||||||
|
|
||||||
|
GhostIDFormat string
|
||||||
|
|
||||||
|
OGSLock sync.RWMutex
|
||||||
|
OutGroupSessions map[id.RoomID]*crypto.OutboundGroupSession
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ crypto.Store = (*SQLCryptoStore)(nil)
|
||||||
|
|
||||||
|
func NewSQLCryptoStore(db *Database, deviceID id.DeviceID) *SQLCryptoStore {
|
||||||
|
return &SQLCryptoStore{
|
||||||
|
db: db,
|
||||||
|
log: db.log.Sub("CryptoStore"),
|
||||||
|
PickleKey: []byte("maunium.net/go/mautrix-whatsapp"),
|
||||||
|
DeviceID: deviceID,
|
||||||
|
|
||||||
|
OutGroupSessions: make(map[id.RoomID]*crypto.OutboundGroupSession),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *Database) FindDeviceID() (deviceID id.DeviceID) {
|
||||||
|
err := db.QueryRow("SELECT device_id FROM crypto_account LIMIT 1").Scan(&deviceID)
|
||||||
|
if err != nil && err != sql.ErrNoRows {
|
||||||
|
db.log.Warnln("Failed to scan device ID:", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.UserID, err error) {
|
||||||
|
var rows *sql.Rows
|
||||||
|
rows, err = store.db.Query(`
|
||||||
|
SELECT user_id FROM mx_user_profile
|
||||||
|
WHERE room_id=$1
|
||||||
|
AND (membership='join' OR membership='invite')
|
||||||
|
AND user_id<>$2
|
||||||
|
AND user_id NOT LIKE $3
|
||||||
|
`, roomID, store.UserID, store.GhostIDFormat)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for rows.Next() {
|
||||||
|
var userID id.UserID
|
||||||
|
err := rows.Scan(&userID)
|
||||||
|
if err != nil {
|
||||||
|
store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
|
||||||
|
} else {
|
||||||
|
members = append(members, userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) Flush() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) PutNextBatch(nextBatch string) {
|
||||||
|
store.SyncToken = nextBatch
|
||||||
|
_, err := store.db.Exec(`UPDATE crypto_account SET sync_token=$1 WHERE device_id=$2`, store.SyncToken, store.DeviceID)
|
||||||
|
if err != nil {
|
||||||
|
store.log.Warnln("Failed to store sync token:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) GetNextBatch() string {
|
||||||
|
if store.SyncToken == "" {
|
||||||
|
err := store.db.
|
||||||
|
QueryRow("SELECT sync_token FROM crypto_account WHERE device_id=$1", store.DeviceID).
|
||||||
|
Scan(&store.SyncToken)
|
||||||
|
if err != nil && err != sql.ErrNoRows {
|
||||||
|
store.log.Warnln("Failed to scan sync token:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return store.SyncToken
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) PutAccount(account *crypto.OlmAccount) error {
|
||||||
|
store.Account = account
|
||||||
|
bytes := account.Internal.Pickle(store.PickleKey)
|
||||||
|
var err error
|
||||||
|
if store.db.dialect == "postgres" {
|
||||||
|
_, err = store.db.Exec(`
|
||||||
|
INSERT INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)
|
||||||
|
ON CONFLICT (device_id) DO UPDATE SET shared=$2, sync_token=$3, account=$4`,
|
||||||
|
store.DeviceID, account.Shared, store.SyncToken, bytes)
|
||||||
|
} else if store.db.dialect == "sqlite3" {
|
||||||
|
_, err = store.db.Exec("INSERT OR REPLACE INTO crypto_account (deivce_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)",
|
||||||
|
store.DeviceID, account.Shared, store.SyncToken, bytes)
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
store.log.Warnln("Failed to store account:", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) GetAccount() (*crypto.OlmAccount, error) {
|
||||||
|
if store.Account == nil {
|
||||||
|
row := store.db.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE device_id=$1", store.DeviceID)
|
||||||
|
acc := &crypto.OlmAccount{Internal: *olm.NewBlankAccount()}
|
||||||
|
var accountBytes []byte
|
||||||
|
err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = acc.Internal.Unpickle(accountBytes, store.PickleKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
store.Account = acc
|
||||||
|
}
|
||||||
|
return store.Account, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
|
||||||
|
// TODO this may need to be changed if olm sessions start expiring
|
||||||
|
var sessionID id.SessionID
|
||||||
|
err := store.db.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 LIMIT 1", key).Scan(&sessionID)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return len(sessionID) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (crypto.OlmSessionList, error) {
|
||||||
|
rows, err := store.db.Query("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id", key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
list := crypto.OlmSessionList{}
|
||||||
|
for rows.Next() {
|
||||||
|
sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
|
||||||
|
var sessionBytes []byte
|
||||||
|
err := rows.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = sess.Internal.Unpickle(sessionBytes, store.PickleKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
list = append(list, &sess)
|
||||||
|
}
|
||||||
|
return list, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*crypto.OlmSession, error) {
|
||||||
|
row := store.db.QueryRow("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id DESC LIMIT 1", key)
|
||||||
|
sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
|
||||||
|
var sessionBytes []byte
|
||||||
|
err := row.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &sess, sess.Internal.Unpickle(sessionBytes, store.PickleKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *crypto.OlmSession) error {
|
||||||
|
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||||
|
_, err := store.db.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_used) VALUES ($1, $2, $3, $4, $5)",
|
||||||
|
session.ID(), key, sessionBytes, session.CreationTime, session.UseTime)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *crypto.InboundGroupSession) error {
|
||||||
|
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||||
|
forwardingChains := strings.Join(session.ForwardingChains, ",")
|
||||||
|
_, err := store.db.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, signing_key, room_id, session, forwarding_chains) VALUES ($1, $2, $3, $4, $5, $6)",
|
||||||
|
sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
|
||||||
|
var signingKey id.Ed25519
|
||||||
|
var sessionBytes []byte
|
||||||
|
var forwardingChains string
|
||||||
|
err := store.db.QueryRow(`
|
||||||
|
SELECT signing_key, session, forwarding_chains
|
||||||
|
FROM crypto_megolm_inbound_session
|
||||||
|
WHERE room_id=$1 AND sender_key=$2 AND session_id=$3`,
|
||||||
|
roomID, senderKey, sessionID,
|
||||||
|
).Scan(&signingKey, &sessionBytes, &forwardingChains)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
igs := olm.NewBlankInboundGroupSession()
|
||||||
|
err = igs.Unpickle(sessionBytes, store.PickleKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &crypto.InboundGroupSession{
|
||||||
|
Internal: *igs,
|
||||||
|
SigningKey: signingKey,
|
||||||
|
SenderKey: senderKey,
|
||||||
|
RoomID: roomID,
|
||||||
|
ForwardingChains: strings.Split(forwardingChains, ","),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) PutOutboundGroupSession(roomID id.RoomID, session *crypto.OutboundGroupSession) error {
|
||||||
|
store.OGSLock.Lock()
|
||||||
|
store.OutGroupSessions[roomID] = session
|
||||||
|
store.OGSLock.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) {
|
||||||
|
store.OGSLock.RLock()
|
||||||
|
defer store.OGSLock.RUnlock()
|
||||||
|
return store.OutGroupSessions[roomID], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) PopOutboundGroupSession(roomID id.RoomID) error {
|
||||||
|
store.OGSLock.Lock()
|
||||||
|
delete(store.OutGroupSessions, roomID)
|
||||||
|
store.OGSLock.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool {
|
||||||
|
var resultEventID id.EventID
|
||||||
|
var resultTimestamp int64
|
||||||
|
err := store.db.QueryRow(
|
||||||
|
"SELECT event_id, timestamp FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND index=$3",
|
||||||
|
senderKey, sessionID, index,
|
||||||
|
).Scan(&resultEventID, &resultTimestamp)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
_, err := store.db.Exec("INSERT INTO crypto_message_index (sender_key, session_id, index, event_id, timestamp) VALUES ($1, $2, $3, $4, $5)",
|
||||||
|
senderKey, sessionID, index, eventID, timestamp)
|
||||||
|
if err != nil {
|
||||||
|
store.log.Warnln("Failed to store message index:", err)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
} else if err != nil {
|
||||||
|
store.log.Warnln("Failed to scan message index:", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if resultEventID != eventID || resultTimestamp != timestamp {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*crypto.DeviceIdentity, error) {
|
||||||
|
var ignore id.UserID
|
||||||
|
err := store.db.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := store.db.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1", userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
data := make(map[id.DeviceID]*crypto.DeviceIdentity)
|
||||||
|
for rows.Next() {
|
||||||
|
var identity crypto.DeviceIdentity
|
||||||
|
err := rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
identity.UserID = userID
|
||||||
|
data[identity.DeviceID] = &identity
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*crypto.DeviceIdentity) error {
|
||||||
|
tx, err := store.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if store.db.dialect == "postgres" {
|
||||||
|
_, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
|
||||||
|
} else if store.db.dialect == "sqlite3" {
|
||||||
|
_, err = tx.Exec("INSERT OR IGNORE INTO crypto_tracked_users (user_id) VALUES ($1)", userID)
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "failed to add user to tracked users list")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Exec("DELETE FROM crypto_device WHERE user_id=$1", userID)
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return errors.Wrap(err, "failed to delete old devices")
|
||||||
|
}
|
||||||
|
if len(devices) == 0 {
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "failed to commit changes (no devices added)")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// TODO do this in batches to avoid too large db queries
|
||||||
|
values := make([]interface{}, 1, len(devices)*6+1)
|
||||||
|
values[0] = userID
|
||||||
|
valueStrings := make([]string, 0, len(devices))
|
||||||
|
i := 2
|
||||||
|
for deviceID, identity := range devices {
|
||||||
|
values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
|
||||||
|
valueStrings = append(valueStrings, fmt.Sprintf("($1, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5))
|
||||||
|
i += 6
|
||||||
|
}
|
||||||
|
valueString := strings.Join(valueStrings, ",")
|
||||||
|
_, err = tx.Exec("INSERT INTO crypto_device (user_id, device_id, identity_key, signing_key, trust, deleted, name) VALUES "+valueString, values...)
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return errors.Wrap(err, "failed to insert new devices")
|
||||||
|
}
|
||||||
|
err = tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "failed to commit changes")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
|
||||||
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
|
if store.db.dialect == "postgres" {
|
||||||
|
rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", pq.Array(users))
|
||||||
|
} else {
|
||||||
|
rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ($1)", users)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
store.log.Warnln("Failed to filter tracked users:", err)
|
||||||
|
return users
|
||||||
|
}
|
||||||
|
var ptr int
|
||||||
|
for rows.Next() {
|
||||||
|
err = rows.Scan(&users[ptr])
|
||||||
|
if err != nil {
|
||||||
|
store.log.Warnln("Failed to tracked user ID:", err)
|
||||||
|
} else {
|
||||||
|
ptr++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return users[:ptr]
|
||||||
|
}
|
|
@ -39,11 +39,13 @@ type SQLStateStore struct {
|
||||||
typingLock sync.RWMutex
|
typingLock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ appservice.StateStore = (*SQLStateStore)(nil)
|
||||||
|
|
||||||
func NewSQLStateStore(db *Database) *SQLStateStore {
|
func NewSQLStateStore(db *Database) *SQLStateStore {
|
||||||
return &SQLStateStore{
|
return &SQLStateStore{
|
||||||
TypingStateStore: appservice.NewTypingStateStore(),
|
TypingStateStore: appservice.NewTypingStateStore(),
|
||||||
db: db,
|
db: db,
|
||||||
log: log.Sub("StateStore"),
|
log: db.log.Sub("StateStore"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,24 +92,6 @@ func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*even
|
||||||
return members
|
return members
|
||||||
}
|
}
|
||||||
|
|
||||||
func (store *SQLStateStore) GetRoomMemberList(roomID id.RoomID) (members []id.UserID, err error) {
|
|
||||||
var rows *sql.Rows
|
|
||||||
rows, err = store.db.Query("SELECT user_id FROM mx_user_profile WHERE room_id=$1", roomID)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for rows.Next() {
|
|
||||||
var userID id.UserID
|
|
||||||
err := rows.Scan(&userID)
|
|
||||||
if err != nil {
|
|
||||||
store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
|
|
||||||
} else {
|
|
||||||
members = append(members, userID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
|
func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
|
||||||
row := store.db.QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID)
|
row := store.db.QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID)
|
||||||
membership := event.MembershipLeave
|
membership := event.MembershipLeave
|
||||||
|
@ -138,8 +122,10 @@ func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*e
|
||||||
|
|
||||||
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
|
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
|
||||||
rows, err := store.db.Query(`
|
rows, err := store.db.Query(`
|
||||||
SELECT room_id FROM mx_user_profile WHERE user_id=$2 AND portal.encrypted=true
|
SELECT room_id FROM mx_user_profile
|
||||||
LEFT JOIN portal WHEN portal.mxid=mx_user_profile.room_id`, userID)
|
LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
|
||||||
|
WHERE user_id=$1 AND portal.encrypted=true
|
||||||
|
`, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err)
|
store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err)
|
||||||
return
|
return
|
||||||
|
|
74
database/upgrades/2020-05-09-crypto-store.go
Normal file
74
database/upgrades/2020-05-09-crypto-store.go
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
package upgrades
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
upgrades[13] = upgrade{"Add crypto store to database", func(tx *sql.Tx, ctx context) error {
|
||||||
|
// TODO use DATETIME instead of timestamp and BLOB instead of bytea for sqlite
|
||||||
|
_, err := tx.Exec(`CREATE TABLE crypto_account (
|
||||||
|
device_id VARCHAR(255) PRIMARY KEY,
|
||||||
|
shared BOOLEAN NOT NULL,
|
||||||
|
sync_token TEXT NOT NULL,
|
||||||
|
account bytea NOT NULL
|
||||||
|
)`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec(`CREATE TABLE crypto_message_index (
|
||||||
|
sender_key CHAR(43),
|
||||||
|
session_id VARCHAR(255),
|
||||||
|
index INTEGER,
|
||||||
|
event_id VARCHAR(255) NOT NULL,
|
||||||
|
timestamp BIGINT NOT NULL,
|
||||||
|
|
||||||
|
PRIMARY KEY (sender_key, session_id, index)
|
||||||
|
)`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec(`CREATE TABLE crypto_tracked_user (
|
||||||
|
user_id VARCHAR(255) PRIMARY KEY
|
||||||
|
)`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec(`CREATE TABLE crypto_device (
|
||||||
|
user_id VARCHAR(255),
|
||||||
|
device_id VARCHAR(255),
|
||||||
|
identity_key CHAR(43) NOT NULL,
|
||||||
|
signing_key CHAR(43) NOT NULL,
|
||||||
|
trust SMALLINT NOT NULL,
|
||||||
|
deleted BOOLEAN NOT NULL,
|
||||||
|
name VARCHAR(255) NOT NULL,
|
||||||
|
|
||||||
|
PRIMARY KEY (user_id, device_id)
|
||||||
|
)`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec(`CREATE TABLE crypto_olm_session (
|
||||||
|
session_id CHAR(43) PRIMARY KEY,
|
||||||
|
sender_key VARCHAR(255) NOT NULL,
|
||||||
|
session bytea NOT NULL,
|
||||||
|
created_at timestamp NOT NULL,
|
||||||
|
last_used timestamp NOT NULL
|
||||||
|
)`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec(`CREATE TABLE crypto_megolm_inbound_session (
|
||||||
|
session_id CHAR(43) PRIMARY KEY,
|
||||||
|
sender_key CHAR(43) NOT NULL,
|
||||||
|
signing_key CHAR(43) NOT NULL,
|
||||||
|
room_id VARCHAR(255) NOT NULL,
|
||||||
|
session bytea NOT NULL,
|
||||||
|
forwarding_chains bytea NOT NULL
|
||||||
|
)`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}}
|
||||||
|
}
|
|
@ -28,7 +28,7 @@ type upgrade struct {
|
||||||
fn upgradeFunc
|
fn upgradeFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
const NumberOfUpgrades = 13
|
const NumberOfUpgrades = 14
|
||||||
|
|
||||||
var upgrades [NumberOfUpgrades]upgrade
|
var upgrades [NumberOfUpgrades]upgrade
|
||||||
|
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -15,7 +15,7 @@ require (
|
||||||
gopkg.in/yaml.v2 v2.2.8
|
gopkg.in/yaml.v2 v2.2.8
|
||||||
maunium.net/go/mauflag v1.0.0
|
maunium.net/go/mauflag v1.0.0
|
||||||
maunium.net/go/maulogger/v2 v2.1.1
|
maunium.net/go/maulogger/v2 v2.1.1
|
||||||
maunium.net/go/mautrix v0.4.0
|
maunium.net/go/mautrix v0.4.1
|
||||||
)
|
)
|
||||||
|
|
||||||
replace github.com/Rhymen/go-whatsapp => github.com/tulir/go-whatsapp v0.2.6
|
replace github.com/Rhymen/go-whatsapp => github.com/tulir/go-whatsapp v0.2.6
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -86,3 +86,5 @@ maunium.net/go/mautrix v0.3.7 h1:N0czrZeAwjvBrw2a/B2G6U3EwIYaWpt7OuSslGp8DRc=
|
||||||
maunium.net/go/mautrix v0.3.7/go.mod h1:SkGZzch8CvU2qKtNpYxtzZ0sQxfVEJ3IsVVLSUBUx9Y=
|
maunium.net/go/mautrix v0.3.7/go.mod h1:SkGZzch8CvU2qKtNpYxtzZ0sQxfVEJ3IsVVLSUBUx9Y=
|
||||||
maunium.net/go/mautrix v0.4.0 h1:IYfmxCoxR/6UMi92IncsSZeKQbZm8Xa35XIRX814KJ4=
|
maunium.net/go/mautrix v0.4.0 h1:IYfmxCoxR/6UMi92IncsSZeKQbZm8Xa35XIRX814KJ4=
|
||||||
maunium.net/go/mautrix v0.4.0/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
|
maunium.net/go/mautrix v0.4.0/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
|
||||||
|
maunium.net/go/mautrix v0.4.1 h1:i2lJNT+TE4AAL3cVKUN4jKVRkujCE/oS8aIsj8+7iNE=
|
||||||
|
maunium.net/go/mautrix v0.4.1/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
|
||||||
|
|
14
main.go
14
main.go
|
@ -126,6 +126,7 @@ type Crypto interface {
|
||||||
HandleMemberEvent(*event.Event)
|
HandleMemberEvent(*event.Event)
|
||||||
Decrypt(*event.Event) (*event.Event, error)
|
Decrypt(*event.Event) (*event.Event, error)
|
||||||
Encrypt(id.RoomID, event.Type, event.Content) (*event.EncryptedEventContent, error)
|
Encrypt(id.RoomID, event.Type, event.Content) (*event.EncryptedEventContent, error)
|
||||||
|
Init() error
|
||||||
Start()
|
Start()
|
||||||
Stop()
|
Stop()
|
||||||
}
|
}
|
||||||
|
@ -225,11 +226,7 @@ func (bridge *Bridge) Init() {
|
||||||
bridge.Log.Debugln("Initializing Matrix event handler")
|
bridge.Log.Debugln("Initializing Matrix event handler")
|
||||||
bridge.MatrixHandler = NewMatrixHandler(bridge)
|
bridge.MatrixHandler = NewMatrixHandler(bridge)
|
||||||
bridge.Formatter = NewFormatter(bridge)
|
bridge.Formatter = NewFormatter(bridge)
|
||||||
err = bridge.initCrypto()
|
bridge.Crypto = NewCryptoHelper(bridge)
|
||||||
if err != nil {
|
|
||||||
bridge.Log.Fatalln("Error initializing end-to-bridge encryption:", err)
|
|
||||||
os.Exit(19)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) Start() {
|
func (bridge *Bridge) Start() {
|
||||||
|
@ -238,6 +235,13 @@ func (bridge *Bridge) Start() {
|
||||||
bridge.Log.Fatalln("Failed to initialize database:", err)
|
bridge.Log.Fatalln("Failed to initialize database:", err)
|
||||||
os.Exit(15)
|
os.Exit(15)
|
||||||
}
|
}
|
||||||
|
if bridge.Crypto != nil {
|
||||||
|
err := bridge.Crypto.Init()
|
||||||
|
if err != nil {
|
||||||
|
bridge.Log.Fatalln("Error initializing end-to-bridge encryption:", err)
|
||||||
|
os.Exit(19)
|
||||||
|
}
|
||||||
|
}
|
||||||
if bridge.Provisioning != nil {
|
if bridge.Provisioning != nil {
|
||||||
bridge.Log.Debugln("Initializing provisioning API")
|
bridge.Log.Debugln("Initializing provisioning API")
|
||||||
bridge.Provisioning.Init()
|
bridge.Provisioning.Init()
|
||||||
|
|
Loading…
Reference in a new issue