forked from MirrorHub/mautrix-whatsapp
Store outbound group sessions in database
This commit is contained in:
parent
1c3de877db
commit
c9adb3aba3
5 changed files with 66 additions and 22 deletions
|
@ -22,7 +22,6 @@ import (
|
|||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
|
@ -44,9 +43,6 @@ type SQLCryptoStore struct {
|
|||
Account *crypto.OlmAccount
|
||||
|
||||
GhostIDFormat string
|
||||
|
||||
OGSLock sync.RWMutex
|
||||
OutGroupSessions map[id.RoomID]*crypto.OutboundGroupSession
|
||||
}
|
||||
|
||||
var _ crypto.Store = (*SQLCryptoStore)(nil)
|
||||
|
@ -57,8 +53,6 @@ func NewSQLCryptoStore(db *Database, deviceID id.DeviceID) *SQLCryptoStore {
|
|||
log: db.log.Sub("CryptoStore"),
|
||||
PickleKey: []byte("maunium.net/go/mautrix-whatsapp"),
|
||||
DeviceID: deviceID,
|
||||
|
||||
OutGroupSessions: make(map[id.RoomID]*crypto.OutboundGroupSession),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -255,24 +249,46 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) PutOutboundGroupSession(roomID id.RoomID, session *crypto.OutboundGroupSession) error {
|
||||
store.OGSLock.Lock()
|
||||
store.OutGroupSessions[roomID] = session
|
||||
store.OGSLock.Unlock()
|
||||
return nil
|
||||
func (store *SQLCryptoStore) AddOutboundGroupSession(session *crypto.OutboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.db.Exec("INSERT INTO crypto_megolm_outbound_session (room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
|
||||
session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount, session.MaxAge, session.CreationTime, session.UseTime)
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *crypto.OutboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.db.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5",
|
||||
sessionBytes, session.MessageCount, session.UseTime, session.RoomID, session.ID())
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) {
|
||||
store.OGSLock.RLock()
|
||||
defer store.OGSLock.RUnlock()
|
||||
return store.OutGroupSessions[roomID], nil
|
||||
var ogs crypto.OutboundGroupSession
|
||||
var sessionBytes []byte
|
||||
err := store.db.QueryRow(`
|
||||
SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
|
||||
FROM crypto_megolm_outbound_session WHERE room_id=$1`,
|
||||
roomID,
|
||||
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &ogs.MaxAge, &ogs.CreationTime, &ogs.UseTime)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
intOGS := olm.NewBlankOutboundGroupSession()
|
||||
err = intOGS.Unpickle(sessionBytes, store.PickleKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ogs.Internal = *intOGS
|
||||
ogs.RoomID = roomID
|
||||
return &ogs, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) PopOutboundGroupSession(roomID id.RoomID) error {
|
||||
store.OGSLock.Lock()
|
||||
delete(store.OutGroupSessions, roomID)
|
||||
store.OGSLock.Unlock()
|
||||
return nil
|
||||
func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
|
||||
_, err := store.db.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1", roomID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool {
|
||||
|
@ -389,7 +405,7 @@ func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
|
|||
queryString[i] = fmt.Sprintf("$%d", i+1)
|
||||
params[i] = user
|
||||
}
|
||||
rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN (" + strings.Join(queryString, ",") + ")", params...)
|
||||
rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
|
||||
}
|
||||
if err != nil {
|
||||
store.log.Warnln("Failed to filter tracked users:", err)
|
||||
|
|
26
database/upgrades/2020-05-12-outbound-group-session-store.go
Normal file
26
database/upgrades/2020-05-12-outbound-group-session-store.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package upgrades
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
upgrades[14] = upgrade{"Add outbound group sessions to database", func(tx *sql.Tx, ctx context) error {
|
||||
// TODO use DATETIME instead of timestamp and BLOB instead of bytea for sqlite
|
||||
_, err := tx.Exec(`CREATE TABLE crypto_megolm_outbound_session (
|
||||
room_id VARCHAR(255) PRIMARY KEY,
|
||||
session_id CHAR(43) NOT NULL UNIQUE,
|
||||
session bytea NOT NULL,
|
||||
shared BOOLEAN NOT NULL,
|
||||
max_messages INTEGER NOT NULL,
|
||||
message_count INTEGER NOT NULL,
|
||||
max_age BIGINT NOT NULL,
|
||||
created_at timestamp NOT NULL,
|
||||
last_used timestamp NOT NULL
|
||||
)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
}
|
|
@ -28,7 +28,7 @@ type upgrade struct {
|
|||
fn upgradeFunc
|
||||
}
|
||||
|
||||
const NumberOfUpgrades = 14
|
||||
const NumberOfUpgrades = 15
|
||||
|
||||
var upgrades [NumberOfUpgrades]upgrade
|
||||
|
||||
|
|
2
go.mod
2
go.mod
|
@ -15,7 +15,7 @@ require (
|
|||
gopkg.in/yaml.v2 v2.2.8
|
||||
maunium.net/go/mauflag v1.0.0
|
||||
maunium.net/go/maulogger/v2 v2.1.1
|
||||
maunium.net/go/mautrix v0.4.3
|
||||
maunium.net/go/mautrix v0.4.4
|
||||
)
|
||||
|
||||
replace github.com/Rhymen/go-whatsapp => github.com/tulir/go-whatsapp v0.2.6
|
||||
|
|
2
go.sum
2
go.sum
|
@ -92,3 +92,5 @@ maunium.net/go/mautrix v0.4.2 h1:GBU++Z7o/fLPcEsNMkNOUsnDknwV/MGPQ0BN4ikK6tw=
|
|||
maunium.net/go/mautrix v0.4.2/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
|
||||
maunium.net/go/mautrix v0.4.3 h1:fVoJy992TjBEvuK5NeO9fpBh+9JuSFsxaEdGjFp/7h4=
|
||||
maunium.net/go/mautrix v0.4.3/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
|
||||
maunium.net/go/mautrix v0.4.4 h1:C5yYDzUdRtJj/9Vot5YBPQUsWmn19sTySew7f4ACLhM=
|
||||
maunium.net/go/mautrix v0.4.4/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
|
||||
|
|
Loading…
Reference in a new issue