diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 5d4eb8848..7be19dad8 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -186,7 +186,8 @@ type PerformServersAliveResponse struct { // QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames type QueryJoinedHostServerNamesInRoomRequest struct { - RoomID string `json:"room_id"` + RoomID string `json:"room_id"` + ExcludeSelf bool `json:"exclude_self"` } // QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 8231fcf44..6a737d0ad 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -128,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { return nil } // send this key change to all servers who share rooms with this user. - destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs) + destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) if err != nil { logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") return nil @@ -180,7 +180,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { return nil } // send this key change to all servers who share rooms with this user. - destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs) + destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) if err != nil { logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") return nil diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 0b1816060..63387b9d8 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -78,7 +78,7 @@ func NewInternalAPI( ) api.FederationInternalAPI { cfg := &base.Cfg.FederationAPI - federationDB, err := storage.NewDatabase(&cfg.Database, base.Caches) + federationDB, err := storage.NewDatabase(&cfg.Database, base.Caches, base.Cfg.Global.ServerName) if err != nil { logrus.WithError(err).Panic("failed to connect to federation sender db") } diff --git a/federationapi/internal/query.go b/federationapi/internal/query.go index bac813331..31d1a3c4f 100644 --- a/federationapi/internal/query.go +++ b/federationapi/internal/query.go @@ -16,7 +16,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom( request *api.QueryJoinedHostServerNamesInRoomRequest, response *api.QueryJoinedHostServerNamesInRoomResponse, ) (err error) { - joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}) + joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf) if err != nil { return } diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index a36f51528..21a919f6a 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -32,7 +32,7 @@ type Database interface { GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. - GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) + GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) PurgeRoomState(ctx context.Context, roomID string) error StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index 1f6afe37c..2e2c08911 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" ) // Database stores information needed by the federation sender @@ -35,7 +36,7 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) { var d Database var err error if d.db, err = sqlutil.Open(dbProperties); err != nil { @@ -89,6 +90,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationC } d.Database = shared.Database{ DB: d.db, + ServerName: serverName, Cache: cache, Writer: d.writer, FederationJoinedHosts: joinedHosts, diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index ddd770e2e..160c7f6fa 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -29,6 +29,7 @@ import ( type Database struct { DB *sql.DB + ServerName gomatrixserverlib.ServerName Cache caching.FederationCache Writer sqlutil.Writer FederationQueuePDUs tables.FederationQueuePDUs @@ -102,8 +103,19 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) } -func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) { - return d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) +func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) { + servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) + if err != nil { + return nil, err + } + if excludeSelf { + for i, server := range servers { + if server == d.ServerName { + servers = append(servers[:i], servers[i+1:]...) + } + } + } + return servers, nil } // StoreJSON adds a JSON blob into the queue JSON table and returns diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index 0fe6df5da..978dd7136 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" ) // Database stores information needed by the federation sender @@ -34,7 +35,7 @@ type Database struct { } // NewDatabase opens a new database -func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (*Database, error) { var d Database var err error if d.db, err = sqlutil.Open(dbProperties); err != nil { @@ -88,6 +89,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationC } d.Database = shared.Database{ DB: d.db, + ServerName: serverName, Cache: cache, Writer: d.writer, FederationJoinedHosts: joinedHosts, diff --git a/federationapi/storage/storage.go b/federationapi/storage/storage.go index 083f0b302..4b52ca206 100644 --- a/federationapi/storage/storage.go +++ b/federationapi/storage/storage.go @@ -24,15 +24,16 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" ) // NewDatabase opens a new database -func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, cache) + return sqlite3.NewDatabase(dbProperties, cache, serverName) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, cache) + return postgres.NewDatabase(dbProperties, cache, serverName) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/federationapi/storage/storage_wasm.go b/federationapi/storage/storage_wasm.go index 455464e7c..09abed63e 100644 --- a/federationapi/storage/storage_wasm.go +++ b/federationapi/storage/storage_wasm.go @@ -20,13 +20,14 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" ) // NewDatabase opens a new database -func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache, serverName gomatrixserverlib.ServerName) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, cache) + return sqlite3.NewDatabase(dbProperties, cache, serverName) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: