0
0
Fork 0
mirror of https://github.com/matrix-org/dendrite synced 2024-05-29 02:33:51 +02:00

Finish inbound E2E device lists (#1243)

* Add tests for device list updates

* Add stale_device_lists table and use db before asking remote for device keys

* Fetch remote keys if all devices are requested

* Add display_name col to store remote device names

Few other tweaks to make `Server correctly handles incoming m.device_list_update`
pass.

* Fix sqlite otk bug

* Unbuffered channel to block /send causing sytest to not race anymore

* Linting and fix bug whereby we didn't send updated dl tokens to the client causing a tightloop on /sync sometimes

* No longer assert staleness as Update blocks on workers now

* Back out tweaks

* Bugfixes
This commit is contained in:
Kegsay 2020-08-07 17:32:13 +01:00 committed by GitHub
parent 30c2325eaf
commit f371783da7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 639 additions and 48 deletions

View file

@ -23,7 +23,6 @@ import (
"time"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/producers"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@ -65,7 +64,7 @@ type DeviceListUpdater struct {
mu *sync.Mutex // protects UserIDToMutex
db DeviceListUpdaterDatabase
producer *producers.KeyChange
producer KeyChangeProducer
fedClient *gomatrixserverlib.FederationClient
workerChans []chan gomatrixserverlib.ServerName
}
@ -88,9 +87,14 @@ type DeviceListUpdaterDatabase interface {
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
}
// KeyChangeProducer is the interface for producers.KeyChange useful for testing.
type KeyChangeProducer interface {
ProduceKeyChanges(keys []api.DeviceMessage) error
}
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
func NewDeviceListUpdater(
db DeviceListUpdaterDatabase, producer *producers.KeyChange, fedClient *gomatrixserverlib.FederationClient,
db DeviceListUpdaterDatabase, producer KeyChangeProducer, fedClient *gomatrixserverlib.FederationClient,
numWorkers int,
) *DeviceListUpdater {
return &DeviceListUpdater{
@ -154,12 +158,17 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
if err != nil {
return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
}
// if this is the first time we're hearing about this user, sync the device list manually.
if len(event.PrevID) == 0 {
exists = false
}
util.GetLogger(ctx).WithFields(logrus.Fields{
"prev_ids_exist": exists,
"user_id": event.UserID,
"device_id": event.DeviceID,
"stream_id": event.StreamID,
"prev_ids": event.PrevID,
"display_name": event.DeviceDisplayName,
}).Info("DeviceListUpdater.Update")
// if we haven't missed anything update the database and notify users
@ -263,16 +272,17 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
hasFailures = true
continue
}
err = u.updateDeviceList(ctx, &res)
err = u.updateDeviceList(&res)
if err != nil {
logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store it")
logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it")
hasFailures = true
}
}
return hasFailures
}
func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixserverlib.RespUserDevices) error {
func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error {
ctx := context.Background() // we've got the keys, don't time out when persisting them to the database.
keys := make([]api.DeviceMessage, len(res.Devices))
for i, device := range res.Devices {
keyJSON, err := json.Marshal(device.Keys)
@ -292,7 +302,15 @@ func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixs
}
err := u.db.StoreRemoteDeviceKeys(ctx, keys)
if err != nil {
return err
return fmt.Errorf("failed to store remote device keys: %w", err)
}
return u.db.MarkDeviceListStale(ctx, res.UserID, false)
err = u.db.MarkDeviceListStale(ctx, res.UserID, false)
if err != nil {
return fmt.Errorf("failed to mark device list as fresh: %w", err)
}
err = u.producer.ProduceKeyChanges(keys)
if err != nil {
return fmt.Errorf("failed to emit key changes for fresh device list: %w", err)
}
return nil
}

View file

@ -0,0 +1,242 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"crypto/ed25519"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"reflect"
"strings"
"sync"
"testing"
"time"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/gomatrixserverlib"
)
var (
ctx = context.Background()
)
type mockKeyChangeProducer struct {
events []api.DeviceMessage
}
func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error {
p.events = append(p.events, keys...)
return nil
}
type mockDeviceListUpdaterDatabase struct {
staleUsers map[string]bool
prevIDsExist func(string, []int) bool
storedKeys []api.DeviceMessage
}
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
var result []string
for userID := range d.staleUsers {
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return nil, err
}
if len(domains) == 0 {
result = append(result, userID)
continue
}
for _, d := range domains {
if remoteServer == d {
result = append(result, userID)
break
}
}
}
return result, nil
}
// MarkDeviceListStale sets the stale bit for this user to isStale.
func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
d.staleUsers[userID] = isStale
return nil
}
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device). Does not modify the stream ID for keys.
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
d.storedKeys = append(d.storedKeys, keys...)
return nil
}
// PrevIDsExists returns true if all prev IDs exist for this user.
func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) {
return d.prevIDsExist(userID, prevIDs), nil
}
type roundTripper struct {
fn func(*http.Request) (*http.Response, error)
}
func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return t.fn(req)
}
func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient {
_, pkey, _ := ed25519.GenerateKey(nil)
fedClient := gomatrixserverlib.NewFederationClient(
gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey,
)
fedClient.Client = *gomatrixserverlib.NewClientWithTransport(&roundTripper{tripper})
return fedClient
}
// Test that the device keys get persisted and emitted if we have the previous IDs.
func TestUpdateHavePrevID(t *testing.T) {
db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool),
prevIDsExist: func(string, []int) bool {
return true
},
}
producer := &mockKeyChangeProducer{}
updater := NewDeviceListUpdater(db, producer, nil, 1)
event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Foo Bar",
Deleted: false,
DeviceID: "FOO",
Keys: []byte(`{"key":"value"}`),
PrevID: []int{0},
StreamID: 1,
UserID: "@alice:localhost",
}
err := updater.Update(ctx, event)
if err != nil {
t.Fatalf("Update returned an error: %s", err)
}
want := api.DeviceMessage{
StreamID: event.StreamID,
DeviceKeys: api.DeviceKeys{
DeviceID: event.DeviceID,
DisplayName: event.DeviceDisplayName,
KeyJSON: event.Keys,
UserID: event.UserID,
},
}
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
}
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
}
if db.staleUsers[event.UserID] {
t.Errorf("%s incorrectly marked as stale", event.UserID)
}
}
// Test that device keys are fetched from the remote server if we are missing prev IDs
// and that the user's devices are marked as stale until it succeeds.
func TestUpdateNoPrevID(t *testing.T) {
db := &mockDeviceListUpdaterDatabase{
staleUsers: make(map[string]bool),
prevIDsExist: func(string, []int) bool {
return false
},
}
producer := &mockKeyChangeProducer{}
remoteUserID := "@alice:example.somewhere"
var wg sync.WaitGroup
wg.Add(1)
keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
defer wg.Done()
if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) {
return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
}
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`
{
"user_id": "` + remoteUserID + `",
"stream_id": 5,
"devices": [
{
"device_id": "JLAFKJWSCS",
"keys": ` + keyJSON + `,
"device_display_name": "Mobile Phone"
}
]
}
`)),
}, nil
})
updater := NewDeviceListUpdater(db, producer, fedClient, 2)
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}
event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Mobile Phone",
Deleted: false,
DeviceID: "another_device_id",
Keys: []byte(`{"key":"value"}`),
PrevID: []int{3},
StreamID: 4,
UserID: remoteUserID,
}
err := updater.Update(ctx, event)
if err != nil {
t.Fatalf("Update returned an error: %s", err)
}
// At this point we show have this device list marked as stale and not store the keys or emitted anything
if !db.staleUsers[event.UserID] {
t.Errorf("%s not marked as stale", event.UserID)
}
if len(producer.events) > 0 {
t.Errorf("Update incorrect emitted %d device change events", len(producer.events))
}
if len(db.storedKeys) > 0 {
t.Errorf("Update incorrect stored %d device change events", len(db.storedKeys))
}
t.Log("waiting for /users/devices to be called...")
wg.Wait()
// wait a bit for db to be updated...
time.Sleep(100 * time.Millisecond)
want := api.DeviceMessage{
StreamID: 5,
DeviceKeys: api.DeviceKeys{
DeviceID: "JLAFKJWSCS",
DisplayName: "Mobile Phone",
UserID: remoteUserID,
KeyJSON: []byte(keyJSON),
},
}
// Now we should have a fresh list and the keys and emitted something
if db.staleUsers[event.UserID] {
t.Errorf("%s still marked as stale", event.UserID)
}
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON))
t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
}
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
}
}

View file

@ -250,10 +250,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
if len(dk.KeyJSON) == 0 {
continue // don't include blank keys
}
// inject display name if known
// inject display name if known (either locally or remotely)
displayName := dk.DisplayName
if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
}
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"`
}{queryRes.DeviceInfo[dk.DeviceID].DisplayName})
}{displayName})
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
}
} else {
@ -261,12 +265,49 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
}
}
// TODO: set device display names when they are known
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
if len(domainToDeviceKeys) == 0 {
return // nothing to query
}
// perform key queries for remote devices
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
}
func (a *KeyInternalAPI) remoteKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
) map[string]map[string][]string {
fetchRemote := make(map[string]map[string][]string)
for domain, userToDeviceMap := range domainToDeviceKeys {
for userID, deviceIDs := range userToDeviceMap {
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
// if we can't query the db or there are fewer keys than requested, fetch from remote.
// Likewise, we can't safely return keys from the db when all devices are requested as we don't
// know if one has just been added.
if len(deviceIDs) == 0 || err != nil || len(keys) < len(deviceIDs) {
if _, ok := fetchRemote[domain]; !ok {
fetchRemote[domain] = make(map[string][]string)
}
fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
continue
}
if res.DeviceKeys[userID] == nil {
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
}
for _, key := range keys {
// inject the display name
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"`
}{key.DisplayName})
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
}
}
}
return fetchRemote
}
func (a *KeyInternalAPI) queryRemoteKeys(
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
) {

View file

@ -37,22 +37,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
-- required in the spec because in the event of a missed update the server fetches the entire
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
stream_id BIGINT NOT NULL,
display_name TEXT,
-- Clobber based on tuple of user/device.
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
);
`
const upsertDeviceKeysSQL = "" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
" VALUES ($1, $2, $3, $4, $5)" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
" VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
" DO UPDATE SET key_json = $4, stream_id = $5"
" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
const selectDeviceKeysSQL = "" +
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@ -99,13 +100,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
for i, key := range keys {
var keyJSONStr string
var streamID int
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows {
return err
}
// this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
if displayName.Valid {
keys[i].DisplayName = displayName.String
}
}
return nil
}
@ -140,7 +145,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
for _, key := range keys {
now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
)
if err != nil {
return err
@ -165,11 +170,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
dk.UserID = userID
var keyJSON string
var streamID int
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err
}
dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID
if displayName.Valid {
dk.DisplayName = displayName.String
}
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk)

View file

@ -0,0 +1,118 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
var staleDeviceListsSchema = `
-- Stores whether a user's device lists are stale or not.
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
user_id TEXT PRIMARY KEY NOT NULL,
domain TEXT NOT NULL,
is_stale BOOLEAN NOT NULL,
ts_added_secs BIGINT NOT NULL
);
CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
`
const upsertStaleDeviceListSQL = "" +
"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
" VALUES ($1, $2, $3, $4)" +
" ON CONFLICT (user_id)" +
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
const selectStaleDeviceListsWithDomainsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
type staleDeviceListsStatements struct {
upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt
}
func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
s := &staleDeviceListsStatements{}
_, err := db.Exec(staleDeviceListsSchema)
if err != nil {
return nil, err
}
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return err
}
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
return err
}
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
// we only query for 1 domain or all domains so optimise for those use cases
if len(domains) == 0 {
rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
if err != nil {
return nil, err
}
return rowsToUserIDs(ctx, rows)
}
var result []string
for _, domain := range domains {
rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
if err != nil {
return nil, err
}
userIDs, err := rowsToUserIDs(ctx, rows)
if err != nil {
return nil, err
}
result = append(result, userIDs...)
}
return result, nil
}
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
result = append(result, userID)
}
return result, rows.Err()
}

View file

@ -38,10 +38,15 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*s
if err != nil {
return nil, err
}
sdl, err := NewPostgresStaleDeviceListsTable(db)
if err != nil {
return nil, err
}
return &shared.Database{
DB: db,
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
DB: db,
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,
}, nil
}

View file

@ -26,10 +26,11 @@ import (
)
type Database struct {
DB *sql.DB
OneTimeKeysTable tables.OneTimeKeys
DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges
DB *sql.DB
OneTimeKeysTable tables.OneTimeKeys
DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges
StaleDeviceListsTable tables.StaleDeviceLists
}
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
@ -129,10 +130,10 @@ func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset,
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
return nil, nil // TODO
return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
}
// MarkDeviceListStale sets the stale bit for this user to isStale.
func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
return nil // TODO
return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
}

View file

@ -34,22 +34,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
stream_id BIGINT NOT NULL,
display_name TEXT,
-- Clobber based on tuple of user/device.
UNIQUE (user_id, device_id)
);
`
const upsertDeviceKeysSQL = "" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
" VALUES ($1, $2, $3, $4, $5)" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
" VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT (user_id, device_id)" +
" DO UPDATE SET key_json = $4, stream_id = $5"
" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
const selectDeviceKeysSQL = "" +
"SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@ -106,11 +107,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
dk.UserID = userID
var keyJSON string
var streamID int
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err
}
dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID
if displayName.Valid {
dk.DisplayName = displayName.String
}
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk)
@ -123,13 +128,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
for i, key := range keys {
var keyJSONStr string
var streamID int
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows {
return err
}
// this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
if displayName.Valid {
keys[i].DisplayName = displayName.String
}
}
return nil
}
@ -171,7 +180,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
for _, key := range keys {
now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
)
if err != nil {
return err

View file

@ -196,6 +196,9 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
return err
})
if keyJSON == "" {
return nil, nil
}
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err

View file

@ -0,0 +1,118 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
var staleDeviceListsSchema = `
-- Stores whether a user's device lists are stale or not.
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
user_id TEXT PRIMARY KEY NOT NULL,
domain TEXT NOT NULL,
is_stale BOOLEAN NOT NULL,
ts_added_secs BIGINT NOT NULL
);
CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
`
const upsertStaleDeviceListSQL = "" +
"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
" VALUES ($1, $2, $3, $4)" +
" ON CONFLICT (user_id)" +
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
const selectStaleDeviceListsWithDomainsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
type staleDeviceListsStatements struct {
upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt
}
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
s := &staleDeviceListsStatements{}
_, err := db.Exec(staleDeviceListsSchema)
if err != nil {
return nil, err
}
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
return nil, err
}
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return err
}
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
return err
}
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
// we only query for 1 domain or all domains so optimise for those use cases
if len(domains) == 0 {
rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
if err != nil {
return nil, err
}
return rowsToUserIDs(ctx, rows)
}
var result []string
for _, domain := range domains {
rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
if err != nil {
return nil, err
}
userIDs, err := rowsToUserIDs(ctx, rows)
if err != nil {
return nil, err
}
result = append(result, userIDs...)
}
return result, nil
}
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
result = append(result, userID)
}
return result, rows.Err()
}

View file

@ -41,10 +41,15 @@ func NewDatabase(dataSourceName string) (*shared.Database, error) {
if err != nil {
return nil, err
}
sdl, err := NewSqliteStaleDeviceListsTable(db)
if err != nil {
return nil, err
}
return &shared.Database{
DB: db,
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
DB: db,
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,
}, nil
}

View file

@ -20,6 +20,7 @@ import (
"encoding/json"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/gomatrixserverlib"
)
type OneTimeKeys interface {
@ -45,3 +46,8 @@ type KeyChanges interface {
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset.
SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
}
type StaleDeviceLists interface {
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
}

View file

@ -46,6 +46,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID,
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
// be already filled in with join/leave information.
// nolint:gocyclo
func DeviceListCatchup(
ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI,
userID string, res *types.Response, from, to types.StreamingToken,
@ -68,22 +69,20 @@ func DeviceListCatchup(
var partition int32
var offset int64
partition = -1
offset = sarama.OffsetOldest
// Extract partition/offset from sync token
// TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make.
logOffset := from.Log(DeviceListLogName)
if logOffset != nil {
partition = logOffset.Partition
offset = logOffset.Offset
} else {
partition = -1
offset = sarama.OffsetOldest
}
var toOffset int64
toOffset = sarama.OffsetNewest
toLog := to.Log(DeviceListLogName)
if toLog != nil {
if toLog != nil && toLog.Offset > 0 {
toOffset = toLog.Offset
} else {
toOffset = sarama.OffsetNewest
}
var queryRes api.QueryKeyChangesResponse
keyAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{
@ -96,6 +95,10 @@ func DeviceListCatchup(
util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed")
return hasNew, nil
}
util.GetLogger(ctx).Debugf(
"QueryKeyChanges request p=%d,off=%d,to=%d response p=%d off=%d uids=%v",
partition, offset, toOffset, queryRes.Partition, queryRes.Offset, queryRes.UserIDs,
)
userSet := make(map[string]bool)
for _, userID := range res.DeviceLists.Changed {
userSet[userID] = true
@ -116,6 +119,13 @@ func DeviceListCatchup(
userSet[userID] = true
}
}
// set the new token
to.SetLog(DeviceListLogName, &types.LogPosition{
Partition: queryRes.Partition,
Offset: queryRes.Offset,
})
res.NextBatch = to.String()
return hasNew, nil
}

View file

@ -112,6 +112,9 @@ type StreamingToken struct {
}
func (t *StreamingToken) SetLog(name string, lp *LogPosition) {
if t.logs == nil {
t.logs = make(map[string]*LogPosition)
}
t.logs[name] = lp
}
@ -173,12 +176,14 @@ func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken)
}
ret.Positions[i] = other.Positions[i]
}
ret.logs = make(map[string]*LogPosition)
for name := range t.logs {
otherLog := other.Log(name)
if otherLog == nil {
continue
}
t.logs[name] = otherLog
copy := *otherLog
ret.logs[name] = &copy
}
return ret
}

View file

@ -138,6 +138,7 @@ Users receive device_list updates for their own devices
Get left notifs for other users in sync and /keys/changes when user leaves
Local device key changes get to remote servers
Local device key changes get to remote servers with correct prev_id
#Server correctly handles incoming m.device_list_update
Can add account data
Can add account data to room
Can get account data without syncing