mirror of
https://github.com/matrix-org/dendrite
synced 2024-12-14 06:33:50 +01:00
Remodel how device list change IDs are created (#2098)
* Remodel how device list change IDs are created Previously we made them using the offset Kafka supplied. We don't run Kafka anymore, so now we make the SQL table assign the change ID via an AUTOINCREMENTing ID. Redesign the `keyserver_key_changes` table to have `UNIQUE(user_id)` so we don't accumulate key changes forevermore, we now have at most 1 row per user which contains the highest change ID. This needs a SQL migration. * Ensure we bump the change ID on sqlite * Actually read the DeviceChangeID not the Offset in synapi * Add SQL migrations * Prepare after migration; fixup dendrite-upgrade-test logging * Use higher version numbers; fix sqlite query to increment better * Default 0 on postgres * fixup postgres migration on fresh dendrite instances
This commit is contained in:
parent
db7d9cba8a
commit
2c581377a5
16 changed files with 336 additions and 158 deletions
|
@ -189,7 +189,9 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir,
|
||||||
if err := decoder.Decode(&dl); err != nil {
|
if err := decoder.Decode(&dl); err != nil {
|
||||||
return "", fmt.Errorf("failed to decode build image output line: %w", err)
|
return "", fmt.Errorf("failed to decode build image output line: %w", err)
|
||||||
}
|
}
|
||||||
log.Printf("%s: %s", branchOrTagName, dl.Stream)
|
if len(strings.TrimSpace(dl.Stream)) > 0 {
|
||||||
|
log.Printf("%s: %s", branchOrTagName, dl.Stream)
|
||||||
|
}
|
||||||
if dl.Aux != nil {
|
if dl.Aux != nil {
|
||||||
imgID, ok := dl.Aux["ID"]
|
imgID, ok := dl.Aux["ID"]
|
||||||
if ok {
|
if ok {
|
||||||
|
@ -425,8 +427,10 @@ func cleanup(dockerClient *client.Client) {
|
||||||
// ignore all errors, we are just cleaning up and don't want to fail just because we fail to cleanup
|
// ignore all errors, we are just cleaning up and don't want to fail just because we fail to cleanup
|
||||||
containers, _ := dockerClient.ContainerList(context.Background(), types.ContainerListOptions{
|
containers, _ := dockerClient.ContainerList(context.Background(), types.ContainerListOptions{
|
||||||
Filters: label(dendriteUpgradeTestLabel),
|
Filters: label(dendriteUpgradeTestLabel),
|
||||||
|
All: true,
|
||||||
})
|
})
|
||||||
for _, c := range containers {
|
for _, c := range containers {
|
||||||
|
log.Printf("Removing container: %v %v\n", c.ID, c.Names)
|
||||||
s := time.Second
|
s := time.Second
|
||||||
_ = dockerClient.ContainerStop(context.Background(), c.ID, &s)
|
_ = dockerClient.ContainerStop(context.Background(), c.ID, &s)
|
||||||
_ = dockerClient.ContainerRemove(context.Background(), c.ID, types.ContainerRemoveOptions{
|
_ = dockerClient.ContainerRemove(context.Background(), c.ID, types.ContainerRemoveOptions{
|
||||||
|
|
|
@ -69,7 +69,8 @@ type DeviceMessage struct {
|
||||||
*DeviceKeys `json:"DeviceKeys,omitempty"`
|
*DeviceKeys `json:"DeviceKeys,omitempty"`
|
||||||
*eduapi.OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
|
*eduapi.OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
|
||||||
// A monotonically increasing number which represents device changes for this user.
|
// A monotonically increasing number which represents device changes for this user.
|
||||||
StreamID int
|
StreamID int
|
||||||
|
DeviceChangeID int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceKeys represents a set of device keys for a single device
|
// DeviceKeys represents a set of device keys for a single device
|
||||||
|
|
|
@ -59,8 +59,7 @@ func (a *KeyInternalAPI) InputDeviceListUpdate(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
|
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
|
||||||
partition := 0
|
userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset)
|
||||||
userIDs, latest, err := a.DB.KeyChanges(ctx, int32(partition), req.Offset, req.ToOffset)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = &api.KeyError{
|
res.Error = &api.KeyError{
|
||||||
Err: err.Error(),
|
Err: err.Error(),
|
||||||
|
|
|
@ -40,16 +40,16 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
|
||||||
func NewInternalAPI(
|
func NewInternalAPI(
|
||||||
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient,
|
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient,
|
||||||
) api.KeyInternalAPI {
|
) api.KeyInternalAPI {
|
||||||
_, consumer, producer := jetstream.Prepare(&cfg.Matrix.JetStream)
|
js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream)
|
||||||
|
|
||||||
db, err := storage.NewDatabase(&cfg.Database)
|
db, err := storage.NewDatabase(&cfg.Database)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Panicf("failed to connect to key server database")
|
logrus.WithError(err).Panicf("failed to connect to key server database")
|
||||||
}
|
}
|
||||||
keyChangeProducer := &producers.KeyChange{
|
keyChangeProducer := &producers.KeyChange{
|
||||||
Topic: string(cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent)),
|
Topic: string(cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent)),
|
||||||
Producer: producer,
|
JetStream: js,
|
||||||
DB: db,
|
DB: db,
|
||||||
}
|
}
|
||||||
ap := &internal.KeyInternalAPI{
|
ap := &internal.KeyInternalAPI{
|
||||||
DB: db,
|
DB: db,
|
||||||
|
|
|
@ -18,43 +18,47 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/Shopify/sarama"
|
|
||||||
eduapi "github.com/matrix-org/dendrite/eduserver/api"
|
eduapi "github.com/matrix-org/dendrite/eduserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage"
|
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||||
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
|
"github.com/nats-io/nats.go"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// KeyChange produces key change events for the sync API and federation sender to consume
|
// KeyChange produces key change events for the sync API and federation sender to consume
|
||||||
type KeyChange struct {
|
type KeyChange struct {
|
||||||
Topic string
|
Topic string
|
||||||
Producer sarama.SyncProducer
|
JetStream nats.JetStreamContext
|
||||||
DB storage.Database
|
DB storage.Database
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProduceKeyChanges creates new change events for each key
|
// ProduceKeyChanges creates new change events for each key
|
||||||
func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
|
func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
|
||||||
userToDeviceCount := make(map[string]int)
|
userToDeviceCount := make(map[string]int)
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
var m sarama.ProducerMessage
|
id, err := p.DB.StoreKeyChange(context.Background(), key.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
key.DeviceChangeID = id
|
||||||
value, err := json.Marshal(key)
|
value, err := json.Marshal(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Topic = string(p.Topic)
|
m := &nats.Msg{
|
||||||
m.Key = sarama.StringEncoder(key.UserID)
|
Subject: p.Topic,
|
||||||
m.Value = sarama.ByteEncoder(value)
|
Header: nats.Header{},
|
||||||
|
}
|
||||||
|
m.Header.Set(jetstream.UserID, key.UserID)
|
||||||
|
m.Data = value
|
||||||
|
|
||||||
partition, offset, err := p.Producer.SendMessage(&m)
|
_, err = p.JetStream.PublishMsg(m)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
userToDeviceCount[key.UserID]++
|
userToDeviceCount[key.UserID]++
|
||||||
}
|
}
|
||||||
for userID, count := range userToDeviceCount {
|
for userID, count := range userToDeviceCount {
|
||||||
|
@ -67,7 +71,6 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *KeyChange) ProduceSigningKeyUpdate(key eduapi.CrossSigningKeyUpdate) error {
|
func (p *KeyChange) ProduceSigningKeyUpdate(key eduapi.CrossSigningKeyUpdate) error {
|
||||||
var m sarama.ProducerMessage
|
|
||||||
output := &api.DeviceMessage{
|
output := &api.DeviceMessage{
|
||||||
Type: api.TypeCrossSigningUpdate,
|
Type: api.TypeCrossSigningUpdate,
|
||||||
OutputCrossSigningKeyUpdate: &eduapi.OutputCrossSigningKeyUpdate{
|
OutputCrossSigningKeyUpdate: &eduapi.OutputCrossSigningKeyUpdate{
|
||||||
|
@ -75,20 +78,25 @@ func (p *KeyChange) ProduceSigningKeyUpdate(key eduapi.CrossSigningKeyUpdate) er
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
id, err := p.DB.StoreKeyChange(context.Background(), key.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
output.DeviceChangeID = id
|
||||||
|
|
||||||
value, err := json.Marshal(output)
|
value, err := json.Marshal(output)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Topic = string(p.Topic)
|
m := &nats.Msg{
|
||||||
m.Key = sarama.StringEncoder(key.UserID)
|
Subject: p.Topic,
|
||||||
m.Value = sarama.ByteEncoder(value)
|
Header: nats.Header{},
|
||||||
|
|
||||||
partition, offset, err := p.Producer.SendMessage(&m)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID)
|
m.Header.Set(jetstream.UserID, key.UserID)
|
||||||
|
m.Data = value
|
||||||
|
|
||||||
|
_, err = p.JetStream.PublishMsg(m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,14 +66,14 @@ type Database interface {
|
||||||
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
||||||
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
|
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
|
||||||
|
|
||||||
// StoreKeyChange stores key change metadata after the change has been sent to Kafka. `userID` is the the user who has changed
|
// StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change.
|
||||||
// their keys in some way.
|
// `userID` is the the user who has changed their keys in some way.
|
||||||
StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
|
StoreKeyChange(ctx context.Context, userID string) (int64, error)
|
||||||
|
|
||||||
// KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
|
// KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
|
||||||
// A to offset of sarama.OffsetNewest means no upper limit.
|
// A to offset of sarama.OffsetNewest means no upper limit.
|
||||||
// Returns the offset of the latest key change.
|
// Returns the offset of the latest key change.
|
||||||
KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||||
|
|
||||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||||
// If no domains are given, all user IDs with stale device lists are returned.
|
// If no domains are given, all user IDs with stale device lists are returned.
|
||||||
|
|
|
@ -0,0 +1,79 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package deltas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/pressly/goose"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoadFromGoose() {
|
||||||
|
goose.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadRefactorKeyChanges(m *sqlutil.Migrations) {
|
||||||
|
m.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpRefactorKeyChanges(tx *sql.Tx) error {
|
||||||
|
// start counting from the last max offset, else 0. We need to do a count(*) first to see if there
|
||||||
|
// even are entries in this table to know if we can query for log_offset. Without the count then
|
||||||
|
// the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't
|
||||||
|
// exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/
|
||||||
|
var count int
|
||||||
|
_ = tx.QueryRow(`SELECT count(*) FROM keyserver_key_changes`).Scan(&count)
|
||||||
|
if count > 0 {
|
||||||
|
var maxOffset int64
|
||||||
|
_ = tx.QueryRow(`SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset)
|
||||||
|
if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil {
|
||||||
|
return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := tx.Exec(`
|
||||||
|
-- make the new table
|
||||||
|
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||||
|
change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
|
||||||
|
);
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DownRefactorKeyChanges(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec(`
|
||||||
|
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
|
||||||
|
DROP SEQUENCE IF EXISTS keyserver_key_changes_seq;
|
||||||
|
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||||
|
partition BIGINT NOT NULL,
|
||||||
|
log_offset BIGINT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
|
||||||
|
);
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -26,27 +26,25 @@ import (
|
||||||
|
|
||||||
var keyChangesSchema = `
|
var keyChangesSchema = `
|
||||||
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq;
|
||||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||||
partition BIGINT NOT NULL,
|
change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
|
||||||
log_offset BIGINT NOT NULL,
|
|
||||||
user_id TEXT NOT NULL,
|
user_id TEXT NOT NULL,
|
||||||
CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
|
CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
|
// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
|
||||||
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
|
// have changed, hence we can just keep bumping the change ID for this user.
|
||||||
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
|
|
||||||
const upsertKeyChangeSQL = "" +
|
const upsertKeyChangeSQL = "" +
|
||||||
"INSERT INTO keyserver_key_changes (partition, log_offset, user_id)" +
|
"INSERT INTO keyserver_key_changes (user_id)" +
|
||||||
" VALUES ($1, $2, $3)" +
|
" VALUES ($1)" +
|
||||||
" ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique" +
|
" ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" +
|
||||||
" DO UPDATE SET user_id = $3"
|
" DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" +
|
||||||
|
" RETURNING change_id"
|
||||||
|
|
||||||
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
|
|
||||||
// take the max offset value as the latest offset.
|
|
||||||
const selectKeyChangesSQL = "" +
|
const selectKeyChangesSQL = "" +
|
||||||
"SELECT user_id, MAX(log_offset) FROM keyserver_key_changes WHERE partition = $1 AND log_offset > $2 AND log_offset <= $3 GROUP BY user_id"
|
"SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
|
||||||
|
|
||||||
type keyChangesStatements struct {
|
type keyChangesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
@ -59,31 +57,32 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
_, err := db.Exec(keyChangesSchema)
|
_, err := db.Exec(keyChangesSchema)
|
||||||
if err != nil {
|
return s, err
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
func (s *keyChangesStatements) Prepare() (err error) {
|
||||||
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
|
||||||
|
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyChangesStatements) SelectKeyChanges(
|
func (s *keyChangesStatements) SelectKeyChanges(
|
||||||
ctx context.Context, partition int32, fromOffset, toOffset int64,
|
ctx context.Context, fromOffset, toOffset int64,
|
||||||
) (userIDs []string, latestOffset int64, err error) {
|
) (userIDs []string, latestOffset int64, err error) {
|
||||||
if toOffset == sarama.OffsetNewest {
|
if toOffset == sarama.OffsetNewest {
|
||||||
toOffset = math.MaxInt64
|
toOffset = math.MaxInt64
|
||||||
}
|
}
|
||||||
latestOffset = fromOffset
|
latestOffset = fromOffset
|
||||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
|
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/shared"
|
"github.com/matrix-org/dendrite/keyserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
)
|
)
|
||||||
|
@ -51,6 +52,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
m := sqlutil.NewMigrations()
|
||||||
|
deltas.LoadRefactorKeyChanges(m)
|
||||||
|
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = kc.Prepare(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
d := &shared.Database{
|
d := &shared.Database{
|
||||||
DB: db,
|
DB: db,
|
||||||
Writer: sqlutil.NewDummyWriter(),
|
Writer: sqlutil.NewDummyWriter(),
|
||||||
|
|
|
@ -135,14 +135,16 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) {
|
||||||
return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
|
err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||||
return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID)
|
id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID)
|
||||||
|
return err
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
|
func (d *Database) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
|
||||||
return d.KeyChangesTable.SelectKeyChanges(ctx, partition, fromOffset, toOffset)
|
return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||||
|
|
|
@ -0,0 +1,76 @@
|
||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package deltas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/pressly/goose"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoadFromGoose() {
|
||||||
|
goose.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadRefactorKeyChanges(m *sqlutil.Migrations) {
|
||||||
|
m.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpRefactorKeyChanges(tx *sql.Tx) error {
|
||||||
|
// start counting from the last max offset, else 0.
|
||||||
|
var maxOffset int64
|
||||||
|
var userID string
|
||||||
|
_ = tx.QueryRow(`SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset)
|
||||||
|
|
||||||
|
_, err := tx.Exec(`
|
||||||
|
-- make the new table
|
||||||
|
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||||
|
change_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
-- The key owner
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
UNIQUE (user_id)
|
||||||
|
);
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||||
|
}
|
||||||
|
// to start counting from maxOffset, insert a row with that value
|
||||||
|
if userID != "" {
|
||||||
|
_, err = tx.Exec(`INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DownRefactorKeyChanges(tx *sql.Tx) error {
|
||||||
|
_, err := tx.Exec(`
|
||||||
|
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
|
||||||
|
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||||
|
partition BIGINT NOT NULL,
|
||||||
|
offset BIGINT NOT NULL,
|
||||||
|
-- The key owner
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
UNIQUE (partition, offset)
|
||||||
|
);
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -27,27 +27,22 @@ import (
|
||||||
var keyChangesSchema = `
|
var keyChangesSchema = `
|
||||||
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||||
partition BIGINT NOT NULL,
|
change_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
offset BIGINT NOT NULL,
|
|
||||||
-- The key owner
|
-- The key owner
|
||||||
user_id TEXT NOT NULL,
|
user_id TEXT NOT NULL,
|
||||||
UNIQUE (partition, offset)
|
UNIQUE (user_id)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
|
// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
|
||||||
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
|
// have changed, hence we can just keep bumping the change ID for this user.
|
||||||
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
|
|
||||||
const upsertKeyChangeSQL = "" +
|
const upsertKeyChangeSQL = "" +
|
||||||
"INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
|
"INSERT OR REPLACE INTO keyserver_key_changes (user_id)" +
|
||||||
" VALUES ($1, $2, $3)" +
|
" VALUES ($1)" +
|
||||||
" ON CONFLICT (partition, offset)" +
|
" RETURNING change_id"
|
||||||
" DO UPDATE SET user_id = $3"
|
|
||||||
|
|
||||||
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
|
|
||||||
// take the max offset value as the latest offset.
|
|
||||||
const selectKeyChangesSQL = "" +
|
const selectKeyChangesSQL = "" +
|
||||||
"SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id"
|
"SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
|
||||||
|
|
||||||
type keyChangesStatements struct {
|
type keyChangesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
@ -60,31 +55,32 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
_, err := db.Exec(keyChangesSchema)
|
_, err := db.Exec(keyChangesSchema)
|
||||||
if err != nil {
|
return s, err
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
func (s *keyChangesStatements) Prepare() (err error) {
|
||||||
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
|
||||||
|
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyChangesStatements) SelectKeyChanges(
|
func (s *keyChangesStatements) SelectKeyChanges(
|
||||||
ctx context.Context, partition int32, fromOffset, toOffset int64,
|
ctx context.Context, fromOffset, toOffset int64,
|
||||||
) (userIDs []string, latestOffset int64, err error) {
|
) (userIDs []string, latestOffset int64, err error) {
|
||||||
if toOffset == sarama.OffsetNewest {
|
if toOffset == sarama.OffsetNewest {
|
||||||
toOffset = math.MaxInt64
|
toOffset = math.MaxInt64
|
||||||
}
|
}
|
||||||
latestOffset = fromOffset
|
latestOffset = fromOffset
|
||||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
|
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@ package sqlite3
|
||||||
import (
|
import (
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/shared"
|
"github.com/matrix-org/dendrite/keyserver/storage/shared"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -49,6 +50,15 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m := sqlutil.NewMigrations()
|
||||||
|
deltas.LoadRefactorKeyChanges(m)
|
||||||
|
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = kc.Prepare(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
d := &shared.Database{
|
d := &shared.Database{
|
||||||
DB: db,
|
DB: db,
|
||||||
Writer: sqlutil.NewExclusiveWriter(),
|
Writer: sqlutil.NewExclusiveWriter(),
|
||||||
|
|
|
@ -44,15 +44,18 @@ func MustNotError(t *testing.T, err error) {
|
||||||
func TestKeyChanges(t *testing.T) {
|
func TestKeyChanges(t *testing.T) {
|
||||||
db, clean := MustCreateDatabase(t)
|
db, clean := MustCreateDatabase(t)
|
||||||
defer clean()
|
defer clean()
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
|
MustNotError(t, err)
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
|
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||||
userIDs, latest, err := db.KeyChanges(ctx, 0, 1, sarama.OffsetNewest)
|
MustNotError(t, err)
|
||||||
|
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||||
|
MustNotError(t, err)
|
||||||
|
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, sarama.OffsetNewest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||||
}
|
}
|
||||||
if latest != 2 {
|
if latest != deviceChangeIDC {
|
||||||
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
|
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
|
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
|
||||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||||
|
@ -62,15 +65,21 @@ func TestKeyChanges(t *testing.T) {
|
||||||
func TestKeyChangesNoDupes(t *testing.T) {
|
func TestKeyChangesNoDupes(t *testing.T) {
|
||||||
db, clean := MustCreateDatabase(t)
|
db, clean := MustCreateDatabase(t)
|
||||||
defer clean()
|
defer clean()
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost"))
|
MustNotError(t, err)
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost"))
|
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||||
userIDs, latest, err := db.KeyChanges(ctx, 0, 0, sarama.OffsetNewest)
|
MustNotError(t, err)
|
||||||
|
if deviceChangeIDA == deviceChangeIDB {
|
||||||
|
t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
|
||||||
|
}
|
||||||
|
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||||
|
MustNotError(t, err)
|
||||||
|
userIDs, latest, err := db.KeyChanges(ctx, 0, sarama.OffsetNewest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||||
}
|
}
|
||||||
if latest != 2 {
|
if latest != deviceChangeID {
|
||||||
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
|
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
|
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
|
||||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||||
|
@ -80,15 +89,18 @@ func TestKeyChangesNoDupes(t *testing.T) {
|
||||||
func TestKeyChangesUpperLimit(t *testing.T) {
|
func TestKeyChangesUpperLimit(t *testing.T) {
|
||||||
db, clean := MustCreateDatabase(t)
|
db, clean := MustCreateDatabase(t)
|
||||||
defer clean()
|
defer clean()
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
|
MustNotError(t, err)
|
||||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
|
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||||
userIDs, latest, err := db.KeyChanges(ctx, 0, 0, 1)
|
MustNotError(t, err)
|
||||||
|
_, err = db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||||
|
MustNotError(t, err)
|
||||||
|
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||||
}
|
}
|
||||||
if latest != 1 {
|
if latest != deviceChangeIDB {
|
||||||
t.Fatalf("KeyChanges: got latest=%d want 1", latest)
|
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
|
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
|
||||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||||
|
|
|
@ -44,10 +44,12 @@ type DeviceKeys interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type KeyChanges interface {
|
type KeyChanges interface {
|
||||||
InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
|
InsertKeyChange(ctx context.Context, userID string) (int64, error)
|
||||||
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
|
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
|
||||||
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset.
|
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset.
|
||||||
SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||||
|
|
||||||
|
Prepare() error
|
||||||
}
|
}
|
||||||
|
|
||||||
type StaleDeviceLists interface {
|
type StaleDeviceLists interface {
|
||||||
|
|
|
@ -17,7 +17,6 @@ package consumers
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
|
@ -34,16 +33,14 @@ import (
|
||||||
|
|
||||||
// OutputKeyChangeEventConsumer consumes events that originated in the key server.
|
// OutputKeyChangeEventConsumer consumes events that originated in the key server.
|
||||||
type OutputKeyChangeEventConsumer struct {
|
type OutputKeyChangeEventConsumer struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
keyChangeConsumer *internal.ContinualConsumer
|
keyChangeConsumer *internal.ContinualConsumer
|
||||||
db storage.Database
|
db storage.Database
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
stream types.StreamProvider
|
stream types.StreamProvider
|
||||||
serverName gomatrixserverlib.ServerName // our server name
|
serverName gomatrixserverlib.ServerName // our server name
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI
|
rsAPI roomserverAPI.RoomserverInternalAPI
|
||||||
keyAPI api.KeyInternalAPI
|
keyAPI api.KeyInternalAPI
|
||||||
partitionToOffset map[int32]int64
|
|
||||||
partitionToOffsetMu sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer.
|
// NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer.
|
||||||
|
@ -69,16 +66,14 @@ func NewOutputKeyChangeEventConsumer(
|
||||||
}
|
}
|
||||||
|
|
||||||
s := &OutputKeyChangeEventConsumer{
|
s := &OutputKeyChangeEventConsumer{
|
||||||
ctx: process.Context(),
|
ctx: process.Context(),
|
||||||
keyChangeConsumer: &consumer,
|
keyChangeConsumer: &consumer,
|
||||||
db: store,
|
db: store,
|
||||||
serverName: serverName,
|
serverName: serverName,
|
||||||
keyAPI: keyAPI,
|
keyAPI: keyAPI,
|
||||||
rsAPI: rsAPI,
|
rsAPI: rsAPI,
|
||||||
partitionToOffset: make(map[int32]int64),
|
notifier: notifier,
|
||||||
partitionToOffsetMu: sync.Mutex{},
|
stream: stream,
|
||||||
notifier: notifier,
|
|
||||||
stream: stream,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
consumer.ProcessMessage = s.onMessage
|
consumer.ProcessMessage = s.onMessage
|
||||||
|
@ -88,24 +83,10 @@ func NewOutputKeyChangeEventConsumer(
|
||||||
|
|
||||||
// Start consuming from the key server
|
// Start consuming from the key server
|
||||||
func (s *OutputKeyChangeEventConsumer) Start() error {
|
func (s *OutputKeyChangeEventConsumer) Start() error {
|
||||||
offsets, err := s.keyChangeConsumer.StartOffsets()
|
return s.keyChangeConsumer.Start()
|
||||||
s.partitionToOffsetMu.Lock()
|
|
||||||
for _, o := range offsets {
|
|
||||||
s.partitionToOffset[o.Partition] = o.Offset
|
|
||||||
}
|
|
||||||
s.partitionToOffsetMu.Unlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *OutputKeyChangeEventConsumer) updateOffset(msg *sarama.ConsumerMessage) {
|
|
||||||
s.partitionToOffsetMu.Lock()
|
|
||||||
defer s.partitionToOffsetMu.Unlock()
|
|
||||||
s.partitionToOffset[msg.Partition] = msg.Offset
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
|
func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
|
||||||
defer s.updateOffset(msg)
|
|
||||||
|
|
||||||
var m api.DeviceMessage
|
var m api.DeviceMessage
|
||||||
if err := json.Unmarshal(msg.Value, &m); err != nil {
|
if err := json.Unmarshal(msg.Value, &m); err != nil {
|
||||||
logrus.WithError(err).Errorf("failed to read device message from key change topic")
|
logrus.WithError(err).Errorf("failed to read device message from key change topic")
|
||||||
|
@ -118,15 +99,15 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
|
||||||
}
|
}
|
||||||
switch m.Type {
|
switch m.Type {
|
||||||
case api.TypeCrossSigningUpdate:
|
case api.TypeCrossSigningUpdate:
|
||||||
return s.onCrossSigningMessage(m, msg.Offset)
|
return s.onCrossSigningMessage(m, m.DeviceChangeID)
|
||||||
case api.TypeDeviceKeyUpdate:
|
case api.TypeDeviceKeyUpdate:
|
||||||
fallthrough
|
fallthrough
|
||||||
default:
|
default:
|
||||||
return s.onDeviceKeyMessage(m, msg.Offset)
|
return s.onDeviceKeyMessage(m, m.DeviceChangeID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, offset int64) error {
|
func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, deviceChangeID int64) error {
|
||||||
if m.DeviceKeys == nil {
|
if m.DeviceKeys == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -143,7 +124,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, o
|
||||||
}
|
}
|
||||||
// make sure we get our own key updates too!
|
// make sure we get our own key updates too!
|
||||||
queryRes.UserIDsToCount[output.UserID] = 1
|
queryRes.UserIDsToCount[output.UserID] = 1
|
||||||
posUpdate := types.StreamPosition(offset)
|
posUpdate := types.StreamPosition(deviceChangeID)
|
||||||
|
|
||||||
s.stream.Advance(posUpdate)
|
s.stream.Advance(posUpdate)
|
||||||
for userID := range queryRes.UserIDsToCount {
|
for userID := range queryRes.UserIDsToCount {
|
||||||
|
@ -153,7 +134,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, o
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, offset int64) error {
|
func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, deviceChangeID int64) error {
|
||||||
output := m.CrossSigningKeyUpdate
|
output := m.CrossSigningKeyUpdate
|
||||||
// work out who we need to notify about the new key
|
// work out who we need to notify about the new key
|
||||||
var queryRes roomserverAPI.QuerySharedUsersResponse
|
var queryRes roomserverAPI.QuerySharedUsersResponse
|
||||||
|
@ -167,7 +148,7 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage
|
||||||
}
|
}
|
||||||
// make sure we get our own key updates too!
|
// make sure we get our own key updates too!
|
||||||
queryRes.UserIDsToCount[output.UserID] = 1
|
queryRes.UserIDsToCount[output.UserID] = 1
|
||||||
posUpdate := types.StreamPosition(offset)
|
posUpdate := types.StreamPosition(deviceChangeID)
|
||||||
|
|
||||||
s.stream.Advance(posUpdate)
|
s.stream.Advance(posUpdate)
|
||||||
for userID := range queryRes.UserIDsToCount {
|
for userID := range queryRes.UserIDsToCount {
|
||||||
|
|
Loading…
Reference in a new issue