Store outbound group sessions in database

This commit is contained in:
Tulir Asokan 2020-05-12 23:16:33 +03:00
parent 1c3de877db
commit c9adb3aba3
5 changed files with 66 additions and 22 deletions

View file

@ -22,7 +22,6 @@ import (
"database/sql"
"fmt"
"strings"
"sync"
"github.com/lib/pq"
"github.com/pkg/errors"
@ -44,9 +43,6 @@ type SQLCryptoStore struct {
Account *crypto.OlmAccount
GhostIDFormat string
OGSLock sync.RWMutex
OutGroupSessions map[id.RoomID]*crypto.OutboundGroupSession
}
var _ crypto.Store = (*SQLCryptoStore)(nil)
@ -57,8 +53,6 @@ func NewSQLCryptoStore(db *Database, deviceID id.DeviceID) *SQLCryptoStore {
log: db.log.Sub("CryptoStore"),
PickleKey: []byte("maunium.net/go/mautrix-whatsapp"),
DeviceID: deviceID,
OutGroupSessions: make(map[id.RoomID]*crypto.OutboundGroupSession),
}
}
@ -255,24 +249,46 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
}, nil
}
func (store *SQLCryptoStore) PutOutboundGroupSession(roomID id.RoomID, session *crypto.OutboundGroupSession) error {
store.OGSLock.Lock()
store.OutGroupSessions[roomID] = session
store.OGSLock.Unlock()
return nil
func (store *SQLCryptoStore) AddOutboundGroupSession(session *crypto.OutboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.db.Exec("INSERT INTO crypto_megolm_outbound_session (room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount, session.MaxAge, session.CreationTime, session.UseTime)
return err
}
func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *crypto.OutboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.db.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5",
sessionBytes, session.MessageCount, session.UseTime, session.RoomID, session.ID())
return err
}
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) {
store.OGSLock.RLock()
defer store.OGSLock.RUnlock()
return store.OutGroupSessions[roomID], nil
var ogs crypto.OutboundGroupSession
var sessionBytes []byte
err := store.db.QueryRow(`
SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
FROM crypto_megolm_outbound_session WHERE room_id=$1`,
roomID,
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &ogs.MaxAge, &ogs.CreationTime, &ogs.UseTime)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
}
intOGS := olm.NewBlankOutboundGroupSession()
err = intOGS.Unpickle(sessionBytes, store.PickleKey)
if err != nil {
return nil, err
}
ogs.Internal = *intOGS
ogs.RoomID = roomID
return &ogs, nil
}
func (store *SQLCryptoStore) PopOutboundGroupSession(roomID id.RoomID) error {
store.OGSLock.Lock()
delete(store.OutGroupSessions, roomID)
store.OGSLock.Unlock()
return nil
func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
_, err := store.db.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1", roomID)
return err
}
func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool {
@ -389,7 +405,7 @@ func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
queryString[i] = fmt.Sprintf("$%d", i+1)
params[i] = user
}
rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN (" + strings.Join(queryString, ",") + ")", params...)
rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
}
if err != nil {
store.log.Warnln("Failed to filter tracked users:", err)

View file

@ -0,0 +1,26 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[14] = upgrade{"Add outbound group sessions to database", func(tx *sql.Tx, ctx context) error {
// TODO use DATETIME instead of timestamp and BLOB instead of bytea for sqlite
_, err := tx.Exec(`CREATE TABLE crypto_megolm_outbound_session (
room_id VARCHAR(255) PRIMARY KEY,
session_id CHAR(43) NOT NULL UNIQUE,
session bytea NOT NULL,
shared BOOLEAN NOT NULL,
max_messages INTEGER NOT NULL,
message_count INTEGER NOT NULL,
max_age BIGINT NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL
)`)
if err != nil {
return err
}
return nil
}}
}

View file

@ -28,7 +28,7 @@ type upgrade struct {
fn upgradeFunc
}
const NumberOfUpgrades = 14
const NumberOfUpgrades = 15
var upgrades [NumberOfUpgrades]upgrade

2
go.mod
View file

@ -15,7 +15,7 @@ require (
gopkg.in/yaml.v2 v2.2.8
maunium.net/go/mauflag v1.0.0
maunium.net/go/maulogger/v2 v2.1.1
maunium.net/go/mautrix v0.4.3
maunium.net/go/mautrix v0.4.4
)
replace github.com/Rhymen/go-whatsapp => github.com/tulir/go-whatsapp v0.2.6

2
go.sum
View file

@ -92,3 +92,5 @@ maunium.net/go/mautrix v0.4.2 h1:GBU++Z7o/fLPcEsNMkNOUsnDknwV/MGPQ0BN4ikK6tw=
maunium.net/go/mautrix v0.4.2/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
maunium.net/go/mautrix v0.4.3 h1:fVoJy992TjBEvuK5NeO9fpBh+9JuSFsxaEdGjFp/7h4=
maunium.net/go/mautrix v0.4.3/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
maunium.net/go/mautrix v0.4.4 h1:C5yYDzUdRtJj/9Vot5YBPQUsWmn19sTySew7f4ACLhM=
maunium.net/go/mautrix v0.4.4/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=