diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 97bcc12a5..ff01b1952 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -15,6 +15,8 @@ package federationapi import ( + "time" + "github.com/gorilla/mux" "github.com/matrix-org/dendrite/federationapi/api" federationAPI "github.com/matrix-org/dendrite/federationapi/api" @@ -167,5 +169,16 @@ func NewInternalAPI( if err = presenceConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start presence consumer") } + + var cleanExpiredEDUs func() + cleanExpiredEDUs = func() { + logrus.Infof("Cleaning expired EDUs") + if err := federationDB.DeleteExpiredEDUs(base.Context()); err != nil { + logrus.WithError(err).Error("Failed to clean expired EDUs") + } + time.AfterFunc(time.Hour, cleanExpiredEDUs) + } + time.AfterFunc(time.Minute, cleanExpiredEDUs) + return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues, keyRing) } diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index b6edec5da..0d937ffaf 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -127,6 +127,7 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share oq.destination, // the destination server name receipt, // NIDs from federationapi_queue_json table event.Type, + nil, // this will use the default expireEDUTypes map ); err != nil { logrus.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination) return diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 29254948b..b8109b432 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -16,6 +16,7 @@ package storage import ( "context" + "time" "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/types" @@ -38,7 +39,7 @@ type Database interface { GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error - AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string) error + AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error @@ -70,4 +71,6 @@ type Database interface { // Query the notary for the server keys for the given server. If `optKeyIDs` is not empty, multiple server keys may be returned (between 1 - len(optKeyIDs)) // such that the combination of all server keys will include all the `optKeyIDs`. GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) + // DeleteExpiredEDUs cleans up expired EDUs + DeleteExpiredEDUs(ctx context.Context) error } diff --git a/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go b/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go new file mode 100644 index 000000000..53a7a025e --- /dev/null +++ b/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go @@ -0,0 +1,44 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/matrix-org/gomatrixserverlib" +) + +func UpAddexpiresat(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus ADD COLUMN IF NOT EXISTS expires_at BIGINT NOT NULL DEFAULT 0;") + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + _, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", gomatrixserverlib.AsTimestamp(time.Now().Add(time.Hour*24))) + if err != nil { + return fmt.Errorf("failed to update queue_edus: %w", err) + } + return nil +} + +func DownAddexpiresat(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus DROP COLUMN expires_at;") + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/federationapi/storage/postgres/queue_edus_table.go b/federationapi/storage/postgres/queue_edus_table.go index 1fedf0ef1..d6507e13b 100644 --- a/federationapi/storage/postgres/queue_edus_table.go +++ b/federationapi/storage/postgres/queue_edus_table.go @@ -19,9 +19,11 @@ import ( "database/sql" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" ) const queueEDUsSchema = ` @@ -31,7 +33,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( -- The domain part of the user ID the EDU event is for. server_name TEXT NOT NULL, -- The JSON NID from the federationsender_queue_edus_json table. - json_nid BIGINT NOT NULL + json_nid BIGINT NOT NULL, + -- The expiry time of this edu, if any. + expires_at BIGINT NOT NULL DEFAULT 0 ); CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx @@ -43,8 +47,8 @@ CREATE INDEX IF NOT EXISTS federationsender_queue_edus_server_name_idx ` const insertQueueEDUSQL = "" + - "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" + - " VALUES ($1, $2, $3)" + "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid, expires_at)" + + " VALUES ($1, $2, $3, $4)" const deleteQueueEDUSQL = "" + "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid = ANY($2)" @@ -65,6 +69,12 @@ const selectQueueEDUCountSQL = "" + const selectQueueServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_edus" +const selectExpiredEDUsSQL = "" + + "SELECT DISTINCT json_nid FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1" + +const deleteExpiredEDUsSQL = "" + + "DELETE FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1" + type queueEDUsStatements struct { db *sql.DB insertQueueEDUStmt *sql.Stmt @@ -73,6 +83,8 @@ type queueEDUsStatements struct { selectQueueEDUReferenceJSONCountStmt *sql.Stmt selectQueueEDUCountStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt + selectExpiredEDUsStmt *sql.Stmt + deleteExpiredEDUsStmt *sql.Stmt } func NewPostgresQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { @@ -81,27 +93,34 @@ func NewPostgresQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { } _, err = s.db.Exec(queueEDUsSchema) if err != nil { - return + return s, err } - if s.insertQueueEDUStmt, err = s.db.Prepare(insertQueueEDUSQL); err != nil { - return + + m := sqlutil.NewMigrator(db) + m.AddMigrations( + sqlutil.Migration{ + Version: "federationapi: add expiresat column", + Up: deltas.UpAddexpiresat, + }, + ) + if err := m.Up(context.Background()); err != nil { + return s, err } - if s.deleteQueueEDUStmt, err = s.db.Prepare(deleteQueueEDUSQL); err != nil { - return - } - if s.selectQueueEDUStmt, err = s.db.Prepare(selectQueueEDUSQL); err != nil { - return - } - if s.selectQueueEDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil { - return - } - if s.selectQueueEDUCountStmt, err = s.db.Prepare(selectQueueEDUCountSQL); err != nil { - return - } - if s.selectQueueEDUServerNamesStmt, err = s.db.Prepare(selectQueueServerNamesSQL); err != nil { - return - } - return + + return s, nil +} + +func (s *queueEDUsStatements) Prepare() error { + return sqlutil.StatementList{ + {&s.insertQueueEDUStmt, insertQueueEDUSQL}, + {&s.deleteQueueEDUStmt, deleteQueueEDUSQL}, + {&s.selectQueueEDUStmt, selectQueueEDUSQL}, + {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, + {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL}, + {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, + {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, + {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, + }.Prepare(s.db) } func (s *queueEDUsStatements) InsertQueueEDU( @@ -110,6 +129,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( eduType string, serverName gomatrixserverlib.ServerName, nid int64, + expiresAt gomatrixserverlib.Timestamp, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) _, err := stmt.ExecContext( @@ -117,6 +137,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( eduType, // the EDU type serverName, // destination server name nid, // JSON blob NID + expiresAt, // timestamp of expiry ) return err } @@ -150,7 +171,7 @@ func (s *queueEDUsStatements) SelectQueueEDUs( } result = append(result, nid) } - return result, nil + return result, rows.Err() } func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( @@ -200,3 +221,33 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames( return result, rows.Err() } + +func (s *queueEDUsStatements) SelectExpiredEDUs( + ctx context.Context, txn *sql.Tx, + expiredBefore gomatrixserverlib.Timestamp, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt) + rows, err := stmt.QueryContext(ctx, expiredBefore) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectExpiredEDUs: rows.close() failed") + var result []int64 + var nid int64 + for rows.Next() { + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + return result, rows.Err() +} + +func (s *queueEDUsStatements) DeleteExpiredEDUs( + ctx context.Context, txn *sql.Tx, + expiredBefore gomatrixserverlib.Timestamp, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteExpiredEDUsStmt) + _, err := stmt.ExecContext(ctx, expiredBefore) + return err +} diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index 7c2883c1b..6e208d096 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -91,6 +91,9 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + if err = queueEDUs.Prepare(); err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, ServerName: serverName, diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index 02a23338f..b62e5d9c5 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -20,10 +20,21 @@ import ( "encoding/json" "errors" "fmt" + "time" "github.com/matrix-org/gomatrixserverlib" ) +// defaultExpiry for EDUs if not listed below +var defaultExpiry = time.Hour * 24 + +// defaultExpireEDUTypes contains EDUs which can/should be expired after a given time +// if the target server isn't reachable for some reason. +var defaultExpireEDUTypes = map[string]time.Duration{ + gomatrixserverlib.MTyping: time.Minute, + gomatrixserverlib.MPresence: time.Minute * 10, +} + // AssociateEDUWithDestination creates an association that the // destination queues will use to determine which JSON blobs to send // to which servers. @@ -32,7 +43,21 @@ func (d *Database) AssociateEDUWithDestination( serverName gomatrixserverlib.ServerName, receipt *Receipt, eduType string, + expireEDUTypes map[string]time.Duration, ) error { + if expireEDUTypes == nil { + expireEDUTypes = defaultExpireEDUTypes + } + expiresAt := gomatrixserverlib.AsTimestamp(time.Now().Add(defaultExpiry)) + if duration, ok := expireEDUTypes[eduType]; ok { + // Keep EDUs for at least x minutes before deleting them + expiresAt = gomatrixserverlib.AsTimestamp(time.Now().Add(duration)) + } + // We forcibly set m.direct_to_device events to 0, as we always want them + // to be delivered. (required for E2EE) + if eduType == gomatrixserverlib.MDirectToDevice { + expiresAt = 0 + } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if err := d.FederationQueueEDUs.InsertQueueEDU( ctx, // context @@ -40,6 +65,7 @@ func (d *Database) AssociateEDUWithDestination( eduType, // EDU type for coalescing serverName, // destination server name receipt.nid, // NID from the federationapi_queue_json table + expiresAt, // The timestamp this EDU will expire ); err != nil { return fmt.Errorf("InsertQueueEDU: %w", err) } @@ -150,3 +176,26 @@ func (d *Database) GetPendingEDUServerNames( ) ([]gomatrixserverlib.ServerName, error) { return d.FederationQueueEDUs.SelectQueueEDUServerNames(ctx, nil) } + +// DeleteExpiredEDUs deletes expired EDUs +func (d *Database) DeleteExpiredEDUs(ctx context.Context) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + expiredBefore := gomatrixserverlib.AsTimestamp(time.Now()) + jsonNIDs, err := d.FederationQueueEDUs.SelectExpiredEDUs(ctx, txn, expiredBefore) + if err != nil { + return err + } + if len(jsonNIDs) == 0 { + return nil + } + for i := range jsonNIDs { + d.Cache.EvictFederationQueuedEDU(jsonNIDs[i]) + } + + if err = d.FederationQueueJSON.DeleteQueueJSON(ctx, txn, jsonNIDs); err != nil { + return err + } + + return d.FederationQueueEDUs.DeleteExpiredEDUs(ctx, txn, expiredBefore) + }) +} diff --git a/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go b/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go new file mode 100644 index 000000000..c5030163b --- /dev/null +++ b/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go @@ -0,0 +1,68 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/matrix-org/gomatrixserverlib" +) + +func UpAddexpiresat(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus RENAME TO federationsender_queue_edus_old;") + if err != nil { + return fmt.Errorf("failed to rename table: %w", err) + } + + _, err = tx.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( + edu_type TEXT NOT NULL, + server_name TEXT NOT NULL, + json_nid BIGINT NOT NULL, + expires_at BIGINT NOT NULL DEFAULT 0 +); + +CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx + ON federationsender_queue_edus (json_nid, server_name); +`) + if err != nil { + return fmt.Errorf("failed to create new table: %w", err) + } + _, err = tx.ExecContext(ctx, ` +INSERT + INTO federationsender_queue_edus ( + edu_type, server_name, json_nid, expires_at + ) SELECT edu_type, server_name, json_nid, 0 FROM federationsender_queue_edus_old; +`) + if err != nil { + return fmt.Errorf("failed to update queue_edus: %w", err) + } + _, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", gomatrixserverlib.AsTimestamp(time.Now().Add(time.Hour*24))) + if err != nil { + return fmt.Errorf("failed to update queue_edus: %w", err) + } + return nil +} + +func DownAddexpiresat(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus DROP COLUMN expires_at;") + if err != nil { + return fmt.Errorf("failed to rename table: %w", err) + } + return nil +} diff --git a/federationapi/storage/sqlite3/queue_edus_table.go b/federationapi/storage/sqlite3/queue_edus_table.go index f4c84f094..8e7e7901f 100644 --- a/federationapi/storage/sqlite3/queue_edus_table.go +++ b/federationapi/storage/sqlite3/queue_edus_table.go @@ -20,9 +20,11 @@ import ( "fmt" "strings" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" ) const queueEDUsSchema = ` @@ -32,7 +34,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( -- The domain part of the user ID the EDU event is for. server_name TEXT NOT NULL, -- The JSON NID from the federationsender_queue_edus_json table. - json_nid BIGINT NOT NULL + json_nid BIGINT NOT NULL, + -- The expiry time of this edu, if any. + expires_at BIGINT NOT NULL DEFAULT 0 ); CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx @@ -44,8 +48,8 @@ CREATE INDEX IF NOT EXISTS federationsender_queue_edus_server_name_idx ` const insertQueueEDUSQL = "" + - "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" + - " VALUES ($1, $2, $3)" + "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid, expires_at)" + + " VALUES ($1, $2, $3, $4)" const deleteQueueEDUsSQL = "" + "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)" @@ -66,13 +70,22 @@ const selectQueueEDUCountSQL = "" + const selectQueueServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_edus" +const selectExpiredEDUsSQL = "" + + "SELECT DISTINCT json_nid FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1" + +const deleteExpiredEDUsSQL = "" + + "DELETE FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1" + type queueEDUsStatements struct { - db *sql.DB - insertQueueEDUStmt *sql.Stmt + db *sql.DB + insertQueueEDUStmt *sql.Stmt + // deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic selectQueueEDUStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt selectQueueEDUCountStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt + selectExpiredEDUsStmt *sql.Stmt + deleteExpiredEDUsStmt *sql.Stmt } func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { @@ -81,24 +94,33 @@ func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { } _, err = db.Exec(queueEDUsSchema) if err != nil { - return + return s, err } - if s.insertQueueEDUStmt, err = db.Prepare(insertQueueEDUSQL); err != nil { - return + + m := sqlutil.NewMigrator(db) + m.AddMigrations( + sqlutil.Migration{ + Version: "federationapi: add expiresat column", + Up: deltas.UpAddexpiresat, + }, + ) + if err := m.Up(context.Background()); err != nil { + return s, err } - if s.selectQueueEDUStmt, err = db.Prepare(selectQueueEDUSQL); err != nil { - return - } - if s.selectQueueEDUReferenceJSONCountStmt, err = db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil { - return - } - if s.selectQueueEDUCountStmt, err = db.Prepare(selectQueueEDUCountSQL); err != nil { - return - } - if s.selectQueueEDUServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil { - return - } - return + + return s, nil +} + +func (s *queueEDUsStatements) Prepare() error { + return sqlutil.StatementList{ + {&s.insertQueueEDUStmt, insertQueueEDUSQL}, + {&s.selectQueueEDUStmt, selectQueueEDUSQL}, + {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, + {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL}, + {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, + {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, + {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, + }.Prepare(s.db) } func (s *queueEDUsStatements) InsertQueueEDU( @@ -107,6 +129,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( eduType string, serverName gomatrixserverlib.ServerName, nid int64, + expiresAt gomatrixserverlib.Timestamp, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) _, err := stmt.ExecContext( @@ -114,6 +137,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( eduType, // the EDU type serverName, // destination server name nid, // JSON blob NID + expiresAt, // timestamp of expiry ) return err } @@ -159,7 +183,7 @@ func (s *queueEDUsStatements) SelectQueueEDUs( } result = append(result, nid) } - return result, nil + return result, rows.Err() } func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( @@ -209,3 +233,33 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames( return result, rows.Err() } + +func (s *queueEDUsStatements) SelectExpiredEDUs( + ctx context.Context, txn *sql.Tx, + expiredBefore gomatrixserverlib.Timestamp, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt) + rows, err := stmt.QueryContext(ctx, expiredBefore) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectExpiredEDUs: rows.close() failed") + var result []int64 + var nid int64 + for rows.Next() { + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + return result, rows.Err() +} + +func (s *queueEDUsStatements) DeleteExpiredEDUs( + ctx context.Context, txn *sql.Tx, + expiredBefore gomatrixserverlib.Timestamp, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteExpiredEDUsStmt) + _, err := stmt.ExecContext(ctx, expiredBefore) + return err +} diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index 9594aaec6..c89cb6bea 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -90,6 +90,9 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + if err = queueEDUs.Prepare(); err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, ServerName: serverName, diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go new file mode 100644 index 000000000..7eba2cbee --- /dev/null +++ b/federationapi/storage/storage_test.go @@ -0,0 +1,81 @@ +package storage_test + +import ( + "context" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + + "github.com/matrix-org/dendrite/federationapi/storage" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" +) + +func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + b, baseClose := testrig.CreateBaseDendrite(t, dbType) + connStr, dbClose := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewDatabase(b, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, b.Caches, b.Cfg.Global.ServerName) + if err != nil { + t.Fatalf("NewDatabase returned %s", err) + } + return db, func() { + dbClose() + baseClose() + } +} + +func TestExpireEDUs(t *testing.T) { + var expireEDUTypes = map[string]time.Duration{ + gomatrixserverlib.MReceipt: time.Millisecond, + } + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateFederationDatabase(t, dbType) + defer close() + // insert some data + for i := 0; i < 100; i++ { + receipt, err := db.StoreJSON(ctx, "{}") + assert.NoError(t, err) + + err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MReceipt, expireEDUTypes) + assert.NoError(t, err) + } + // add data without expiry + receipt, err := db.StoreJSON(ctx, "{}") + assert.NoError(t, err) + + // m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test + err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, "m.read_marker", expireEDUTypes) + assert.NoError(t, err) + + // Delete expired EDUs + err = db.DeleteExpiredEDUs(ctx) + assert.NoError(t, err) + + // verify the data is gone + data, err := db.GetPendingEDUs(ctx, "localhost", 100) + assert.NoError(t, err) + assert.Equal(t, 1, len(data)) + + // check that m.direct_to_device is never expired + receipt, err = db.StoreJSON(ctx, "{}") + assert.NoError(t, err) + + err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes) + assert.NoError(t, err) + + err = db.DeleteExpiredEDUs(ctx) + assert.NoError(t, err) + + // We should get two EDUs, the m.read_marker and the m.direct_to_device + data, err = db.GetPendingEDUs(ctx, "localhost", 100) + assert.NoError(t, err) + assert.Equal(t, 2, len(data)) + }) +} diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 19357393d..3c116a1d0 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -34,12 +34,15 @@ type FederationQueuePDUs interface { } type FederationQueueEDUs interface { - InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName gomatrixserverlib.ServerName, nid int64) error + InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName gomatrixserverlib.ServerName, nid int64, expiresAt gomatrixserverlib.Timestamp) error DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) + SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error) + DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error + Prepare() error } type FederationQueueJSON interface {