diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 0fea88da6..13065fa6b 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -81,7 +81,7 @@ type Database interface { // Returns a map following the format data[roomID] = []dataTypes // If no data is retrieved, returns an empty map // If there was an issue with the retrieval, returns an error - GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, error) + GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, types.StreamPosition, error) // UpsertAccountData keeps track of new or updated account data, by saving the type // of the new/updated data, and the user ID and room ID the data is related to (empty) // room ID means the data isn't specific to any room) diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index 25bdb1da3..22bb4d7fa 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -57,7 +57,7 @@ const insertAccountDataSQL = "" + " RETURNING id" const selectAccountDataInRangeSQL = "" + - "SELECT room_id, type FROM syncapi_account_data_type" + + "SELECT id, room_id, type FROM syncapi_account_data_type" + " WHERE user_id = $1 AND id > $2 AND id <= $3" + " AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" + " AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" + @@ -103,7 +103,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter, -) (data map[string][]string, err error) { +) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(), @@ -116,11 +116,12 @@ func (s *accountDataStatements) SelectAccountDataInRange( } defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") - for rows.Next() { - var dataType string - var roomID string + var dataType string + var roomID string + var id types.StreamPosition - if err = rows.Scan(&roomID, &dataType); err != nil { + for rows.Next() { + if err = rows.Scan(&id, &roomID, &dataType); err != nil { return } @@ -129,8 +130,11 @@ func (s *accountDataStatements) SelectAccountDataInRange( } else { data[roomID] = []string{dataType} } + if id > pos { + pos = id + } } - return data, rows.Err() + return data, pos, rows.Err() } func (s *accountDataStatements) SelectMaxAccountDataID( diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 3c431db48..69bceb624 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -265,7 +265,7 @@ func (d *Database) DeletePeeks( func (d *Database) GetAccountDataInRange( ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter, -) (map[string][]string, error) { +) (map[string][]string, types.StreamPosition, error) { return d.AccountData.SelectAccountDataInRange(ctx, userID, r, accountDataFilterPart) } diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 71a098177..e0d97ec32 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -43,7 +43,7 @@ const insertAccountDataSQL = "" + // further parameters are added by prepareWithFilters const selectAccountDataInRangeSQL = "" + - "SELECT room_id, type FROM syncapi_account_data_type" + + "SELECT id, room_id, type FROM syncapi_account_data_type" + " WHERE user_id = $1 AND id > $2 AND id <= $3" const selectMaxAccountDataIDSQL = "" + @@ -95,7 +95,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( userID string, r types.Range, filter *gomatrixserverlib.EventFilter, -) (data map[string][]string, err error) { +) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) stmt, params, err := prepareWithFilters( s.db, nil, selectAccountDataInRangeSQL, @@ -112,11 +112,12 @@ func (s *accountDataStatements) SelectAccountDataInRange( } defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") - for rows.Next() { - var dataType string - var roomID string + var dataType string + var roomID string + var id types.StreamPosition - if err = rows.Scan(&roomID, &dataType); err != nil { + for rows.Next() { + if err = rows.Scan(&id, &roomID, &dataType); err != nil { return } @@ -125,9 +126,12 @@ func (s *accountDataStatements) SelectAccountDataInRange( } else { data[roomID] = []string{dataType} } + if id > pos { + pos = id + } } - return data, nil + return data, pos, nil } func (s *accountDataStatements) SelectMaxAccountDataID( diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index ac713dd5c..32b1c34ef 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -27,7 +27,7 @@ import ( type AccountData interface { InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error) // SelectAccountDataInRange returns a map of room ID to a list of `dataType`. - SelectAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, err error) + SelectAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, pos types.StreamPosition, err error) SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error) } diff --git a/syncapi/streams/stream_accountdata.go b/syncapi/streams/stream_accountdata.go index 094c51485..99cd4a92a 100644 --- a/syncapi/streams/stream_accountdata.go +++ b/syncapi/streams/stream_accountdata.go @@ -43,7 +43,7 @@ func (p *AccountDataStreamProvider) IncrementalSync( To: to, } - dataTypes, err := p.DB.GetAccountDataInRange( + dataTypes, pos, err := p.DB.GetAccountDataInRange( ctx, req.Device.UserID, r, &req.Filter.AccountData, ) if err != nil { @@ -95,5 +95,5 @@ func (p *AccountDataStreamProvider) IncrementalSync( } } - return to + return pos }