diff --git a/currentstateserver/api/api.go b/currentstateserver/api/api.go index 520ce8d6e..b778acb21 100644 --- a/currentstateserver/api/api.go +++ b/currentstateserver/api/api.go @@ -36,11 +36,13 @@ type CurrentStateInternalAPI interface { } type QuerySharedUsersRequest struct { - UserID string + UserID string + ExcludeRoomIDs []string + IncludeRoomIDs []string } type QuerySharedUsersResponse struct { - UserIDs []string + UserIDsToCount map[string]int } type QueryRoomsForUserRequest struct { diff --git a/currentstateserver/currentstateserver_test.go b/currentstateserver/currentstateserver_test.go index 4dac742f4..1366a0be8 100644 --- a/currentstateserver/currentstateserver_test.go +++ b/currentstateserver/currentstateserver_test.go @@ -20,7 +20,6 @@ import ( "encoding/json" "net/http" "reflect" - "sort" "testing" "time" @@ -227,13 +226,31 @@ func TestQuerySharedUsers(t *testing.T) { req api.QuerySharedUsersRequest wantRes api.QuerySharedUsersResponse }{ - // Simple case: sharing (A,B) (A,C) (A,B) (A) produces (A,B,C) + // Simple case: sharing (A,B) (A,C) (A,B) (A) produces (A:4,B:2,C:1) { req: api.QuerySharedUsersRequest{ UserID: "@alice:localhost", }, wantRes: api.QuerySharedUsersResponse{ - UserIDs: []string{"@alice:localhost", "@bob:localhost", "@charlie:localhost"}, + UserIDsToCount: map[string]int{ + "@alice:localhost": 4, + "@bob:localhost": 2, + "@charlie:localhost": 1, + }, + }, + }, + + // Exclude (A,C): sharing (A,B) (A,B) (A) produces (A:3,B:2) + { + req: api.QuerySharedUsersRequest{ + UserID: "@alice:localhost", + ExcludeRoomIDs: []string{"!foo2:bar"}, + }, + wantRes: api.QuerySharedUsersResponse{ + UserIDsToCount: map[string]int{ + "@alice:localhost": 3, + "@bob:localhost": 2, + }, }, }, @@ -243,7 +260,7 @@ func TestQuerySharedUsers(t *testing.T) { UserID: "@unknownuser:localhost", }, wantRes: api.QuerySharedUsersResponse{ - UserIDs: nil, + UserIDsToCount: map[string]int{}, }, }, @@ -253,7 +270,35 @@ func TestQuerySharedUsers(t *testing.T) { UserID: "@dave:localhost", }, wantRes: api.QuerySharedUsersResponse{ - UserIDs: nil, + UserIDsToCount: map[string]int{}, + }, + }, + + // left real user but with included room returns the included room member + { + req: api.QuerySharedUsersRequest{ + UserID: "@dave:localhost", + IncludeRoomIDs: []string{"!foo:bar"}, + }, + wantRes: api.QuerySharedUsersResponse{ + UserIDsToCount: map[string]int{ + "@alice:localhost": 1, + "@bob:localhost": 1, + }, + }, + }, + + // including a room more than once doesn't double counts + { + req: api.QuerySharedUsersRequest{ + UserID: "@dave:localhost", + IncludeRoomIDs: []string{"!foo:bar", "!foo:bar", "!foo:bar"}, + }, + wantRes: api.QuerySharedUsersResponse{ + UserIDsToCount: map[string]int{ + "@alice:localhost": 1, + "@bob:localhost": 1, + }, }, }, } @@ -266,10 +311,8 @@ func TestQuerySharedUsers(t *testing.T) { t.Errorf("QuerySharedUsers returned error: %s", err) continue } - sort.Strings(res.UserIDs) - sort.Strings(tc.wantRes.UserIDs) - if !reflect.DeepEqual(res.UserIDs, tc.wantRes.UserIDs) { - t.Errorf("QuerySharedUsers got users %+v want %+v", res.UserIDs, tc.wantRes.UserIDs) + if !reflect.DeepEqual(res.UserIDsToCount, tc.wantRes.UserIDsToCount) { + t.Errorf("QuerySharedUsers got users %+v want %+v", res.UserIDsToCount, tc.wantRes.UserIDsToCount) } } } diff --git a/currentstateserver/internal/api.go b/currentstateserver/internal/api.go index e945d0c11..c581c524c 100644 --- a/currentstateserver/internal/api.go +++ b/currentstateserver/internal/api.go @@ -74,10 +74,27 @@ func (a *CurrentStateInternalAPI) QuerySharedUsers(ctx context.Context, req *api if err != nil { return err } + roomIDs = append(roomIDs, req.IncludeRoomIDs...) + excludeMap := make(map[string]bool) + for _, roomID := range req.ExcludeRoomIDs { + excludeMap[roomID] = true + } + // filter out excluded rooms + j := 0 + for i := range roomIDs { + // move elements to include to the beginning of the slice + // then trim elements on the right + if !excludeMap[roomIDs[i]] { + roomIDs[j] = roomIDs[i] + j++ + } + } + roomIDs = roomIDs[:j] + users, err := a.DB.JoinedUsersSetInRooms(ctx, roomIDs) if err != nil { return err } - res.UserIDs = users + res.UserIDsToCount = users return nil } diff --git a/currentstateserver/storage/interface.go b/currentstateserver/storage/interface.go index 1c4635be2..8deaa3484 100644 --- a/currentstateserver/storage/interface.go +++ b/currentstateserver/storage/interface.go @@ -37,6 +37,6 @@ type Database interface { GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) // Redact a state event RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error - // JoinedUsersSetInRooms returns all joined users in the rooms given. - JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error) + // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. + JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) } diff --git a/currentstateserver/storage/postgres/current_room_state_table.go b/currentstateserver/storage/postgres/current_room_state_table.go index 9e0070f16..294f757cf 100644 --- a/currentstateserver/storage/postgres/current_room_state_table.go +++ b/currentstateserver/storage/postgres/current_room_state_table.go @@ -78,7 +78,8 @@ const selectBulkStateContentWildSQL = "" + "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = ANY($2)" const selectJoinedUsersSetForRoomsSQL = "" + - "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = 'm.room.member' and content_value = 'join'" + "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id = ANY($1) AND" + + " type = 'm.room.member' and content_value = 'join' GROUP BY state_key" type currentRoomStateStatements struct { upsertRoomStateStmt *sql.Stmt @@ -124,21 +125,22 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro return s, nil } -func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) { +func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs)) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") - var userIDs []string + result := make(map[string]int) for rows.Next() { var userID string - if err := rows.Scan(&userID); err != nil { + var count int + if err := rows.Scan(&userID, &count); err != nil { return nil, err } - userIDs = append(userIDs, userID) + result[userID] = count } - return userIDs, rows.Err() + return result, rows.Err() } // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. diff --git a/currentstateserver/storage/shared/storage.go b/currentstateserver/storage/shared/storage.go index aafb5fdd0..dac38790d 100644 --- a/currentstateserver/storage/shared/storage.go +++ b/currentstateserver/storage/shared/storage.go @@ -86,6 +86,6 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) } -func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error) { +func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs) } diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go index 4d3803b64..5706fa35c 100644 --- a/currentstateserver/storage/sqlite3/current_room_state_table.go +++ b/currentstateserver/storage/sqlite3/current_room_state_table.go @@ -67,7 +67,7 @@ const selectBulkStateContentWildSQL = "" + "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2)" const selectJoinedUsersSetForRoomsSQL = "" + - "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join'" + "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join' GROUP BY state_key" type currentRoomStateStatements struct { db *sql.DB @@ -106,7 +106,7 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) return s, nil } -func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) { +func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { iRoomIDs := make([]interface{}, len(roomIDs)) for i, v := range roomIDs { iRoomIDs[i] = v @@ -117,15 +117,16 @@ func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Co return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") - var userIDs []string + result := make(map[string]int) for rows.Next() { var userID string - if err := rows.Scan(&userID); err != nil { + var count int + if err := rows.Scan(&userID, &count); err != nil { return nil, err } - userIDs = append(userIDs, userID) + result[userID] = count } - return userIDs, rows.Err() + return result, rows.Err() } // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. diff --git a/currentstateserver/storage/tables/interface.go b/currentstateserver/storage/tables/interface.go index 88e7a31b7..121bf4fdf 100644 --- a/currentstateserver/storage/tables/interface.go +++ b/currentstateserver/storage/tables/interface.go @@ -36,8 +36,9 @@ type CurrentRoomState interface { // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) SelectBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]StrippedEvent, error) - // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms. - SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) + // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the + // counts of how many rooms they are joined. + SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) } // StrippedEvent represents a stripped event for returning extracted content values.