mirror of
https://github.com/matrix-org/dendrite
synced 2025-01-19 06:31:54 +01:00
Generate stream IDs for locally uploaded device keys (#1236)
* Breaking: add stream_id to keyserver_device_keys table * Add tests for stream ID generation * Fix whitelist
This commit is contained in:
parent
ffcb6d2ea1
commit
fb56bbf0b7
11 changed files with 265 additions and 84 deletions
|
@ -43,6 +43,13 @@ func (k *KeyError) Error() string {
|
||||||
return k.Err
|
return k.Err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeviceMessage represents the message produced into Kafka by the key server.
|
||||||
|
type DeviceMessage struct {
|
||||||
|
DeviceKeys
|
||||||
|
// A monotonically increasing number which represents device changes for this user.
|
||||||
|
StreamID int
|
||||||
|
}
|
||||||
|
|
||||||
// DeviceKeys represents a set of device keys for a single device
|
// DeviceKeys represents a set of device keys for a single device
|
||||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
|
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
|
||||||
type DeviceKeys struct {
|
type DeviceKeys struct {
|
||||||
|
@ -50,10 +57,20 @@ type DeviceKeys struct {
|
||||||
UserID string
|
UserID string
|
||||||
// The device ID of this device
|
// The device ID of this device
|
||||||
DeviceID string
|
DeviceID string
|
||||||
|
// The device display name
|
||||||
|
DisplayName string
|
||||||
// The raw device key JSON
|
// The raw device key JSON
|
||||||
KeyJSON []byte
|
KeyJSON []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithStreamID returns a copy of this device message with the given stream ID
|
||||||
|
func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage {
|
||||||
|
return DeviceMessage{
|
||||||
|
DeviceKeys: *k,
|
||||||
|
StreamID: streamID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// OneTimeKeys represents a set of one-time keys for a single device
|
// OneTimeKeys represents a set of one-time keys for a single device
|
||||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
|
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
|
||||||
type OneTimeKeys struct {
|
type OneTimeKeys struct {
|
||||||
|
|
|
@ -61,7 +61,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC
|
||||||
|
|
||||||
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||||
res.KeyErrors = make(map[string]map[string]*api.KeyError)
|
res.KeyErrors = make(map[string]map[string]*api.KeyError)
|
||||||
a.uploadDeviceKeys(ctx, req, res)
|
a.uploadLocalDeviceKeys(ctx, req, res)
|
||||||
a.uploadOneTimeKeys(ctx, req, res)
|
a.uploadOneTimeKeys(ctx, req, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -286,18 +286,25 @@ func (a *KeyInternalAPI) queryRemoteKeys(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||||
var keysToStore []api.DeviceKeys
|
var keysToStore []api.DeviceMessage
|
||||||
// assert that the user ID / device ID are not lying for each key
|
// assert that the user ID / device ID are not lying for each key
|
||||||
for _, key := range req.DeviceKeys {
|
for _, key := range req.DeviceKeys {
|
||||||
|
_, serverName, err := gomatrixserverlib.SplitID('@', key.UserID)
|
||||||
|
if err != nil {
|
||||||
|
continue // ignore invalid users
|
||||||
|
}
|
||||||
|
if serverName != a.ThisServer {
|
||||||
|
continue // ignore remote users
|
||||||
|
}
|
||||||
if len(key.KeyJSON) == 0 {
|
if len(key.KeyJSON) == 0 {
|
||||||
keysToStore = append(keysToStore, key)
|
keysToStore = append(keysToStore, key.WithStreamID(0))
|
||||||
continue // deleted keys don't need sanity checking
|
continue // deleted keys don't need sanity checking
|
||||||
}
|
}
|
||||||
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
|
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
|
||||||
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
|
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
|
||||||
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
|
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
|
||||||
keysToStore = append(keysToStore, key)
|
keysToStore = append(keysToStore, key.WithStreamID(0))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -310,11 +317,13 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU
|
||||||
}
|
}
|
||||||
|
|
||||||
// get existing device keys so we can check for changes
|
// get existing device keys so we can check for changes
|
||||||
existingKeys := make([]api.DeviceKeys, len(keysToStore))
|
existingKeys := make([]api.DeviceMessage, len(keysToStore))
|
||||||
for i := range keysToStore {
|
for i := range keysToStore {
|
||||||
existingKeys[i] = api.DeviceKeys{
|
existingKeys[i] = api.DeviceMessage{
|
||||||
|
DeviceKeys: api.DeviceKeys{
|
||||||
UserID: keysToStore[i].UserID,
|
UserID: keysToStore[i].UserID,
|
||||||
DeviceID: keysToStore[i].DeviceID,
|
DeviceID: keysToStore[i].DeviceID,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
|
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
|
||||||
|
@ -324,13 +333,14 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// store the device keys and emit changes
|
// store the device keys and emit changes
|
||||||
if err := a.DB.StoreDeviceKeys(ctx, keysToStore); err != nil {
|
err := a.DB.StoreDeviceKeys(ctx, keysToStore)
|
||||||
|
if err != nil {
|
||||||
res.Error = &api.KeyError{
|
res.Error = &api.KeyError{
|
||||||
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err := a.emitDeviceKeyChanges(existingKeys, keysToStore)
|
err = a.emitDeviceKeyChanges(existingKeys, keysToStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
|
util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -375,9 +385,9 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) error {
|
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage) error {
|
||||||
// find keys in new that are not in existing
|
// find keys in new that are not in existing
|
||||||
var keysAdded []api.DeviceKeys
|
var keysAdded []api.DeviceMessage
|
||||||
for _, newKey := range new {
|
for _, newKey := range new {
|
||||||
exists := false
|
exists := false
|
||||||
for _, existingKey := range existing {
|
for _, existingKey := range existing {
|
||||||
|
|
|
@ -41,7 +41,7 @@ func (p *KeyChange) DefaultPartition() int32 {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProduceKeyChanges creates new change events for each key
|
// ProduceKeyChanges creates new change events for each key
|
||||||
func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error {
|
func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
var m sarama.ProducerMessage
|
var m sarama.ProducerMessage
|
||||||
|
|
||||||
|
|
|
@ -32,17 +32,18 @@ type Database interface {
|
||||||
// OneTimeKeysCount returns a count of all OTKs for this device.
|
// OneTimeKeysCount returns a count of all OTKs for this device.
|
||||||
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||||
|
|
||||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced.
|
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
|
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
|
|
||||||
// StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
// StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||||
// for this (user, device).
|
// for this (user, device).
|
||||||
|
// The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
|
||||||
// Returns an error if there was a problem storing the keys.
|
// Returns an error if there was a problem storing the keys.
|
||||||
StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
|
StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
|
|
||||||
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
||||||
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
||||||
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
|
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
||||||
|
|
||||||
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
|
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
|
||||||
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
||||||
|
|
|
@ -20,7 +20,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
)
|
)
|
||||||
|
@ -32,28 +31,37 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
||||||
device_id TEXT NOT NULL,
|
device_id TEXT NOT NULL,
|
||||||
ts_added_secs BIGINT NOT NULL,
|
ts_added_secs BIGINT NOT NULL,
|
||||||
key_json TEXT NOT NULL,
|
key_json TEXT NOT NULL,
|
||||||
|
-- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
|
||||||
|
-- This means we do not store an unbounded append-only log of device keys, which is not actually
|
||||||
|
-- required in the spec because in the event of a missed update the server fetches the entire
|
||||||
|
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
|
||||||
|
stream_id BIGINT NOT NULL,
|
||||||
-- Clobber based on tuple of user/device.
|
-- Clobber based on tuple of user/device.
|
||||||
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
|
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const upsertDeviceKeysSQL = "" +
|
const upsertDeviceKeysSQL = "" +
|
||||||
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" +
|
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
|
||||||
" VALUES ($1, $2, $3, $4)" +
|
" VALUES ($1, $2, $3, $4, $5)" +
|
||||||
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
|
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
|
||||||
" DO UPDATE SET key_json = $4"
|
" DO UPDATE SET key_json = $4, stream_id = $5"
|
||||||
|
|
||||||
const selectDeviceKeysSQL = "" +
|
const selectDeviceKeysSQL = "" +
|
||||||
"SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
|
|
||||||
const selectBatchDeviceKeysSQL = "" +
|
const selectBatchDeviceKeysSQL = "" +
|
||||||
"SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
|
const selectMaxStreamForUserSQL = "" +
|
||||||
|
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
type deviceKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
upsertDeviceKeysStmt *sql.Stmt
|
upsertDeviceKeysStmt *sql.Stmt
|
||||||
selectDeviceKeysStmt *sql.Stmt
|
selectDeviceKeysStmt *sql.Stmt
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
|
selectMaxStreamForUserStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
|
@ -73,38 +81,54 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
|
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
for i, key := range keys {
|
for i, key := range keys {
|
||||||
var keyJSONStr string
|
var keyJSONStr string
|
||||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr)
|
var streamID int
|
||||||
|
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// this will be '' when there is no device
|
// this will be '' when there is no device
|
||||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
keys[i].KeyJSON = []byte(keyJSONStr)
|
||||||
|
keys[i].StreamID = streamID
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
|
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
||||||
now := time.Now().Unix()
|
// nullable if there are no results
|
||||||
return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
|
var nullStream sql.NullInt32
|
||||||
|
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
if nullStream.Valid {
|
||||||
|
streamID = nullStream.Int32
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
|
now := time.Now().Unix()
|
||||||
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
||||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON),
|
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
|
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||||
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -114,15 +138,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
for _, d := range deviceIDs {
|
for _, d := range deviceIDs {
|
||||||
deviceIDMap[d] = true
|
deviceIDMap[d] = true
|
||||||
}
|
}
|
||||||
var result []api.DeviceKeys
|
var result []api.DeviceMessage
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var dk api.DeviceKeys
|
var dk api.DeviceMessage
|
||||||
dk.UserID = userID
|
dk.UserID = userID
|
||||||
var keyJSON string
|
var keyJSON string
|
||||||
if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil {
|
var streamID int
|
||||||
|
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dk.KeyJSON = []byte(keyJSON)
|
dk.KeyJSON = []byte(keyJSON)
|
||||||
|
dk.StreamID = streamID
|
||||||
// include the key if we want all keys (no device) or it was asked
|
// include the key if we want all keys (no device) or it was asked
|
||||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||||
result = append(result, dk)
|
result = append(result, dk)
|
||||||
|
|
|
@ -43,15 +43,36 @@ func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string
|
||||||
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
|
func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
|
func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
return d.DeviceKeysTable.InsertDeviceKeys(ctx, keys)
|
// work out the latest stream IDs for each user
|
||||||
|
userIDToStreamID := make(map[string]int)
|
||||||
|
for _, k := range keys {
|
||||||
|
userIDToStreamID[k.UserID] = 0
|
||||||
|
}
|
||||||
|
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
||||||
|
for userID := range userIDToStreamID {
|
||||||
|
streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
userIDToStreamID[userID] = int(streamID)
|
||||||
|
}
|
||||||
|
// set the stream IDs for each key
|
||||||
|
for i := range keys {
|
||||||
|
k := keys[i]
|
||||||
|
userIDToStreamID[k.UserID]++ // start stream from 1
|
||||||
|
k.StreamID = userIDToStreamID[k.UserID]
|
||||||
|
keys[i] = k
|
||||||
|
}
|
||||||
|
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
|
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||||
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
|
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
)
|
)
|
||||||
|
@ -32,28 +31,33 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
||||||
device_id TEXT NOT NULL,
|
device_id TEXT NOT NULL,
|
||||||
ts_added_secs BIGINT NOT NULL,
|
ts_added_secs BIGINT NOT NULL,
|
||||||
key_json TEXT NOT NULL,
|
key_json TEXT NOT NULL,
|
||||||
|
stream_id BIGINT NOT NULL,
|
||||||
-- Clobber based on tuple of user/device.
|
-- Clobber based on tuple of user/device.
|
||||||
UNIQUE (user_id, device_id)
|
UNIQUE (user_id, device_id)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const upsertDeviceKeysSQL = "" +
|
const upsertDeviceKeysSQL = "" +
|
||||||
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" +
|
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
|
||||||
" VALUES ($1, $2, $3, $4)" +
|
" VALUES ($1, $2, $3, $4, $5)" +
|
||||||
" ON CONFLICT (user_id, device_id)" +
|
" ON CONFLICT (user_id, device_id)" +
|
||||||
" DO UPDATE SET key_json = $4"
|
" DO UPDATE SET key_json = $4, stream_id = $5"
|
||||||
|
|
||||||
const selectDeviceKeysSQL = "" +
|
const selectDeviceKeysSQL = "" +
|
||||||
"SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||||
|
|
||||||
const selectBatchDeviceKeysSQL = "" +
|
const selectBatchDeviceKeysSQL = "" +
|
||||||
"SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
|
const selectMaxStreamForUserSQL = "" +
|
||||||
|
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
type deviceKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
upsertDeviceKeysStmt *sql.Stmt
|
upsertDeviceKeysStmt *sql.Stmt
|
||||||
selectDeviceKeysStmt *sql.Stmt
|
selectDeviceKeysStmt *sql.Stmt
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
|
selectMaxStreamForUserStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
|
@ -73,10 +77,13 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
|
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||||
deviceIDMap := make(map[string]bool)
|
deviceIDMap := make(map[string]bool)
|
||||||
for _, d := range deviceIDs {
|
for _, d := range deviceIDs {
|
||||||
deviceIDMap[d] = true
|
deviceIDMap[d] = true
|
||||||
|
@ -86,15 +93,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
||||||
var result []api.DeviceKeys
|
var result []api.DeviceMessage
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var dk api.DeviceKeys
|
var dk api.DeviceMessage
|
||||||
dk.UserID = userID
|
dk.UserID = userID
|
||||||
var keyJSON string
|
var keyJSON string
|
||||||
if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil {
|
var streamID int
|
||||||
|
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dk.KeyJSON = []byte(keyJSON)
|
dk.KeyJSON = []byte(keyJSON)
|
||||||
|
dk.StreamID = streamID
|
||||||
// include the key if we want all keys (no device) or it was asked
|
// include the key if we want all keys (no device) or it was asked
|
||||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||||
result = append(result, dk)
|
result = append(result, dk)
|
||||||
|
@ -103,30 +112,43 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
return result, rows.Err()
|
return result, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
|
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
for i, key := range keys {
|
for i, key := range keys {
|
||||||
var keyJSONStr string
|
var keyJSONStr string
|
||||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr)
|
var streamID int
|
||||||
|
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// this will be '' when there is no device
|
// this will be '' when there is no device
|
||||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
keys[i].KeyJSON = []byte(keyJSONStr)
|
||||||
|
keys[i].StreamID = streamID
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
|
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
||||||
now := time.Now().Unix()
|
// nullable if there are no results
|
||||||
return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
|
var nullStream sql.NullInt32
|
||||||
|
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
if nullStream.Valid {
|
||||||
|
streamID = nullStream.Int32
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
|
now := time.Now().Unix()
|
||||||
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
|
||||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON),
|
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ctx = context.Background()
|
var ctx = context.Background()
|
||||||
|
@ -77,3 +78,84 @@ func TestKeyChangesUpperLimit(t *testing.T) {
|
||||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
|
||||||
|
// and that they are returned correctly when querying for device keys.
|
||||||
|
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||||
|
db, err := NewDatabase("file::memory:", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to NewDatabase: %s", err)
|
||||||
|
}
|
||||||
|
alice := "@alice:TestDeviceKeysStreamIDGeneration"
|
||||||
|
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
||||||
|
msgs := []api.DeviceMessage{
|
||||||
|
{
|
||||||
|
DeviceKeys: api.DeviceKeys{
|
||||||
|
DeviceID: "AAA",
|
||||||
|
UserID: alice,
|
||||||
|
KeyJSON: []byte(`{"key":"v1"}`),
|
||||||
|
},
|
||||||
|
// StreamID: 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
DeviceKeys: api.DeviceKeys{
|
||||||
|
DeviceID: "AAA",
|
||||||
|
UserID: bob,
|
||||||
|
KeyJSON: []byte(`{"key":"v1"}`),
|
||||||
|
},
|
||||||
|
// StreamID: 1 as this is a different user
|
||||||
|
},
|
||||||
|
{
|
||||||
|
DeviceKeys: api.DeviceKeys{
|
||||||
|
DeviceID: "another_device",
|
||||||
|
UserID: alice,
|
||||||
|
KeyJSON: []byte(`{"key":"v1"}`),
|
||||||
|
},
|
||||||
|
// StreamID: 2 as this is a 2nd device key
|
||||||
|
},
|
||||||
|
}
|
||||||
|
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
|
||||||
|
if msgs[0].StreamID != 1 {
|
||||||
|
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
||||||
|
}
|
||||||
|
if msgs[1].StreamID != 1 {
|
||||||
|
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
||||||
|
}
|
||||||
|
if msgs[2].StreamID != 2 {
|
||||||
|
t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// updating a device sets the next stream ID for that user
|
||||||
|
msgs = []api.DeviceMessage{
|
||||||
|
{
|
||||||
|
DeviceKeys: api.DeviceKeys{
|
||||||
|
DeviceID: "AAA",
|
||||||
|
UserID: alice,
|
||||||
|
KeyJSON: []byte(`{"key":"v2"}`),
|
||||||
|
},
|
||||||
|
// StreamID: 3
|
||||||
|
},
|
||||||
|
}
|
||||||
|
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
|
||||||
|
if msgs[0].StreamID != 3 {
|
||||||
|
t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Querying for device keys returns the latest stream IDs
|
||||||
|
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
||||||
|
}
|
||||||
|
wantStreamIDs := map[string]int{
|
||||||
|
"AAA": 3,
|
||||||
|
"another_device": 2,
|
||||||
|
}
|
||||||
|
if len(msgs) != len(wantStreamIDs) {
|
||||||
|
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
|
||||||
|
}
|
||||||
|
for _, m := range msgs {
|
||||||
|
if m.StreamID != wantStreamIDs[m.DeviceID] {
|
||||||
|
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -32,9 +32,10 @@ type OneTimeKeys interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type DeviceKeys interface {
|
type DeviceKeys interface {
|
||||||
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
|
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
|
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
|
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
||||||
|
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type KeyChanges interface {
|
type KeyChanges interface {
|
||||||
|
|
|
@ -98,7 +98,7 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
|
||||||
defer func() {
|
defer func() {
|
||||||
s.updateOffset(msg)
|
s.updateOffset(msg)
|
||||||
}()
|
}()
|
||||||
var output api.DeviceKeys
|
var output api.DeviceMessage
|
||||||
if err := json.Unmarshal(msg.Value, &output); err != nil {
|
if err := json.Unmarshal(msg.Value, &output); err != nil {
|
||||||
// If the message was invalid, log it and move on to the next message in the stream
|
// If the message was invalid, log it and move on to the next message in the stream
|
||||||
log.WithError(err).Error("syncapi: failed to unmarshal key change event from key server")
|
log.WithError(err).Error("syncapi: failed to unmarshal key change event from key server")
|
||||||
|
|
|
@ -110,6 +110,7 @@ Rooms a user is invited to appear in an incremental sync
|
||||||
Sync can be polled for updates
|
Sync can be polled for updates
|
||||||
Sync is woken up for leaves
|
Sync is woken up for leaves
|
||||||
Newly left rooms appear in the leave section of incremental sync
|
Newly left rooms appear in the leave section of incremental sync
|
||||||
|
Rooms can be created with an initial invite list (SYN-205)
|
||||||
We should see our own leave event, even if history_visibility is restricted (SYN-662)
|
We should see our own leave event, even if history_visibility is restricted (SYN-662)
|
||||||
We should see our own leave event when rejecting an invite, even if history_visibility is restricted (riot-web/3462)
|
We should see our own leave event when rejecting an invite, even if history_visibility is restricted (riot-web/3462)
|
||||||
Newly left rooms appear in the leave section of gapped sync
|
Newly left rooms appear in the leave section of gapped sync
|
||||||
|
|
Loading…
Add table
Reference in a new issue