0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-20 11:38:20 +02:00

Return read-only collections from @cached methods (#13755)

It's important that collections returned from `@cached` methods are not
modified, otherwise future retrievals from the cache will return the
modified collection.

This applies to the return values from `@cached` methods and the values
inside the dictionaries returned by `@cachedList` methods. It's not
necessary for the dictionaries returned by `@cachedList` methods
themselves to be read-only.

Signed-off-by: Sean Quah <seanq@matrix.org>
Co-authored-by: David Robertson <davidr@element.io>
This commit is contained in:
Sean Quah 2023-02-10 23:29:00 +00:00 committed by GitHub
parent 14be78d492
commit d0c713cc85
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 98 additions and 77 deletions

1
changelog.d/13755.misc Normal file
View file

@ -0,0 +1 @@
Re-type hint some collections as read-only.

View file

@ -15,7 +15,7 @@ import logging
import math
import resource
import sys
from typing import TYPE_CHECKING, List, Sized, Tuple
from typing import TYPE_CHECKING, List, Mapping, Sized, Tuple
from prometheus_client import Gauge
@ -194,7 +194,7 @@ def start_phone_stats_home(hs: "HomeServer") -> None:
@wrap_as_background_process("generate_monthly_active_users")
async def generate_monthly_active_users() -> None:
current_mau_count = 0
current_mau_count_by_service = {}
current_mau_count_by_service: Mapping[str, int] = {}
reserved_users: Sized = ()
store = hs.get_datastores().main
if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List
from typing import Any, Collection
from matrix_common.regex import glob_to_regex
@ -70,7 +70,7 @@ class RoomDirectoryConfig(Config):
return False
def is_publishing_room_allowed(
self, user_id: str, room_id: str, aliases: List[str]
self, user_id: str, room_id: str, aliases: Collection[str]
) -> bool:
"""Checks if the given user is allowed to publish the room
@ -122,7 +122,7 @@ class _RoomDirectoryRule:
except Exception as e:
raise ConfigError("Failed to parse glob into regex") from e
def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool:
def matches(self, user_id: str, room_id: str, aliases: Collection[str]) -> bool:
"""Tests if this rule matches the given user_id, room_id and aliases.
Args:

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
import attr
from signedjson.types import SigningKey
@ -103,7 +103,7 @@ class EventBuilder:
async def build(
self,
prev_event_ids: List[str],
prev_event_ids: Collection[str],
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
) -> EventBase:
@ -136,7 +136,7 @@ class EventBuilder:
format_version = self.room_version.event_format
# The types of auth/prev events changes between event versions.
prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
if format_version == EventFormatVersions.ROOM_V1_V2:
auth_events = await self._store.add_event_hashes(auth_event_ids)

View file

@ -23,6 +23,7 @@ from typing import (
Collection,
Dict,
List,
Mapping,
Optional,
Tuple,
Union,
@ -1512,7 +1513,7 @@ class FederationHandlerRegistry:
def _get_event_ids_for_partial_state_join(
join_event: EventBase,
prev_state_ids: StateMap[str],
summary: Dict[str, MemberSummary],
summary: Mapping[str, MemberSummary],
) -> Collection[str]:
"""Calculate state to be returned in a partial_state send_join

View file

@ -14,7 +14,7 @@
import logging
import string
from typing import TYPE_CHECKING, Iterable, List, Optional
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence
from typing_extensions import Literal
@ -486,7 +486,7 @@ class DirectoryHandler:
)
if canonical_alias:
# Ensure we do not mutate room_aliases.
room_aliases = room_aliases + [canonical_alias]
room_aliases = list(room_aliases) + [canonical_alias]
if not self.config.roomdirectory.is_publishing_room_allowed(
user_id, room_id, room_aliases
@ -529,7 +529,7 @@ class DirectoryHandler:
async def get_aliases_for_room(
self, requester: Requester, room_id: str
) -> List[str]:
) -> Sequence[str]:
"""
Get a list of the aliases that currently point to this room on this server
"""

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple
from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.appservice import ApplicationService
@ -189,7 +189,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
@staticmethod
def filter_out_private_receipts(
rooms: List[JsonDict], user_id: str
rooms: Sequence[JsonDict], user_id: str
) -> List[JsonDict]:
"""
Filters a list of serialized receipts (as returned by /sync and /initialSync)

View file

@ -1928,6 +1928,6 @@ class RoomShutdownHandler:
return {
"kicked_users": kicked_users,
"failed_to_kick_users": failed_to_kick_users,
"local_aliases": aliases_for_room,
"local_aliases": list(aliases_for_room),
"new_room_id": new_room_id,
}

View file

@ -1519,7 +1519,7 @@ class SyncHandler:
one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
unused_fallback_key_types = (
unused_fallback_key_types = list(
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
@ -2301,7 +2301,7 @@ class SyncHandler:
sync_result_builder: "SyncResultBuilder",
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[Dict[str, Dict[str, Any]]],
tags: Optional[Mapping[str, Mapping[str, Any]]],
account_data: Mapping[str, JsonDict],
always_include: bool = False,
) -> None:

View file

@ -22,6 +22,7 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
@ -149,7 +150,7 @@ class BulkPushRuleEvaluator:
# little, we can skip fetching a huge number of push rules in large rooms.
# This helps make joins and leaves faster.
if event.type == EventTypes.Member:
local_users = []
local_users: Sequence[str] = []
# We never notify a user about their own actions. This is enforced in
# `_action_for_event_by_user` in the loop over `rules_by_user`, but we
# do the same check here to avoid unnecessary DB queries.
@ -184,7 +185,6 @@ class BulkPushRuleEvaluator:
if event.type == EventTypes.Member and event.membership == Membership.INVITE:
invited = event.state_key
if invited and self.hs.is_mine_id(invited) and invited not in local_users:
local_users = list(local_users)
local_users.append(invited)
if not local_users:

View file

@ -226,7 +226,7 @@ class StateHandler:
return await ret.get_state(self._state_storage_controller, state_filter)
async def get_current_user_ids_in_room(
self, room_id: str, latest_event_ids: List[str]
self, room_id: str, latest_event_ids: Collection[str]
) -> Set[str]:
"""
Get the users IDs who are currently in a room.

View file

@ -14,6 +14,7 @@
import logging
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Awaitable,
Callable,
@ -23,7 +24,6 @@ from typing import (
List,
Mapping,
Optional,
Set,
Tuple,
)
@ -527,7 +527,7 @@ class StateStorageController:
)
return state_map.get(key)
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
"""Get current hosts in room based on current state.
Blocks until we have full state for the given room. This only happens for rooms
@ -584,7 +584,7 @@ class StateStorageController:
async def get_users_in_room_with_profiles(
self, room_id: str
) -> Dict[str, ProfileInfo]:
) -> Mapping[str, ProfileInfo]:
"""
Get the current users in the room with their profiles.
If the room is currently partial-stated, this will block until the room has

View file

@ -240,7 +240,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
@cached(num_args=2, tree=True)
async def get_account_data_for_room(
self, user_id: str, room_id: str
) -> Dict[str, JsonDict]:
) -> Mapping[str, JsonDict]:
"""Get all the client account_data for a user for a room.
Args:

View file

@ -166,7 +166,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
room_id: str,
app_service: "ApplicationService",
cache_context: _CacheContext,
) -> List[str]:
) -> Sequence[str]:
"""
Get all users in a room that the appservice controls.

View file

@ -21,6 +21,7 @@ from typing import (
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
@ -202,7 +203,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
async def count_devices_by_users(
self, user_ids: Optional[Collection[str]] = None
) -> int:
"""Retrieve number of all devices of given users.
Only returns number of devices that are not marked as hidden.
@ -213,7 +216,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"""
def count_devices_by_users_txn(
txn: LoggingTransaction, user_ids: List[str]
txn: LoggingTransaction, user_ids: Collection[str]
) -> int:
sql = """
SELECT count(*)
@ -747,7 +750,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
@cancellable
async def get_user_devices_from_cache(
self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.
Args:
@ -775,16 +778,18 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
# First fetch all the users which all devices are to be returned.
results: Dict[str, Dict[str, JsonDict]] = {}
results: Dict[str, Mapping[str, JsonDict]] = {}
for user_id in user_ids:
if user_id in user_ids_in_cache:
results[user_id] = await self.get_cached_devices_for_user(user_id)
# Then fetch all device-specific requests, but skip users we've already
# fetched all devices for.
device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_id in user_and_device_ids:
if user_id in user_ids_in_cache and user_id not in user_ids:
device = await self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
device_specific_results.setdefault(user_id, {})[device_id] = device
results.update(device_specific_results)
set_tag("in_cache", str(results))
set_tag("not_in_cache", str(user_ids_not_in_cache))
@ -802,7 +807,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return db_to_json(content)
@cached()
async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Sequence, Tuple
import attr
@ -74,7 +74,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore):
)
@cached(max_entries=5000)
async def get_aliases_for_room(self, room_id: str) -> List[str]:
async def get_aliases_for_room(self, room_id: str) -> Sequence[str]:
return await self.db_pool.simple_select_onecol(
"room_aliases",
{"room_id": room_id},

View file

@ -20,7 +20,9 @@ from typing import (
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
cast,
@ -691,7 +693,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types(
self, user_id: str, device_id: str
) -> List[str]:
) -> Sequence[str]:
"""Returns the fallback key types that have an unused key.
Args:
@ -731,7 +733,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return user_keys.get(key_type)
@cached(num_args=1)
def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]:
"""Dummy function. Only used to make a cache for
_get_bare_e2e_cross_signing_keys_bulk.
"""
@ -744,7 +746,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: Iterable[str]
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
@ -765,7 +767,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
# The `Optional` comes from the `@cachedList` decorator.
return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result)
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
@ -924,7 +926,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cancellable
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@ -940,11 +942,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
if from_user_id:
result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_txn,
result,
from_user_id,
result = cast(
Dict[str, Optional[Mapping[str, JsonDict]]],
await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_txn,
result,
from_user_id,
),
)
return result

View file

@ -22,6 +22,7 @@ from typing import (
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
cast,
@ -1004,7 +1005,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
async def get_max_depth_of(
self, event_ids: Collection[str]
) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs
Args:
@ -1141,7 +1144,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
@cached(max_entries=5000, iterable=True)
async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
return await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
@ -1171,7 +1174,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@cancellable
async def get_forward_extremities_for_room_at_stream_ordering(
self, room_id: str, stream_ordering: int
) -> List[str]:
) -> Sequence[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@ -1204,7 +1207,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@cached(max_entries=5000, num_args=2)
async def _get_forward_extremeties_for_room(
self, room_id: str, stream_ordering: int
) -> List[str]:
) -> Sequence[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import (
@ -95,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
return await self.db_pool.runInteraction("count_users", _count_users)
@cached(num_args=0)
async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
async def get_monthly_active_count_by_service(self) -> Mapping[str, int]:
"""Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table

View file

@ -21,7 +21,9 @@ from typing import (
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
cast,
)
@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[dict]:
) -> Sequence[JsonDict]:
"""Get receipts for a single room for sending to clients.
Args:
@ -311,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[JsonDict]:
) -> Sequence[JsonDict]:
"""See get_linearized_receipts_for_room"""
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
@ -354,7 +356,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
async def _get_linearized_receipts_for_rooms(
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
) -> Dict[str, List[JsonDict]]:
) -> Dict[str, Sequence[JsonDict]]:
if not room_ids:
return {}
@ -416,7 +418,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
) -> Dict[str, JsonDict]:
) -> Mapping[str, JsonDict]:
"""Get receipts for all rooms between two stream_ids, up
to a limit of the latest 100 read receipts.

View file

@ -16,7 +16,7 @@
import logging
import random
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
import attr
@ -192,7 +192,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead"""
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:

View file

@ -22,6 +22,7 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
@ -171,7 +172,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
@ -397,7 +398,9 @@ class RelationsWorkerStore(SQLBaseStore):
return result is not None
@cached()
async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]:
async def get_aggregation_groups_for_event(
self, event_id: str
) -> Sequence[JsonDict]:
raise NotImplementedError()
@cachedList(

View file

@ -24,6 +24,7 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
@ -153,7 +154,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return self._known_servers_count
@cached(max_entries=100000, iterable=True)
async def get_users_in_room(self, room_id: str) -> List[str]:
async def get_users_in_room(self, room_id: str) -> Sequence[str]:
"""Returns a list of users in the room.
Will return inaccurate results for rooms with partial state, since the state for
@ -190,9 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached()
def get_user_in_room_with_profile(
self, room_id: str, user_id: str
) -> Dict[str, ProfileInfo]:
def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo:
raise NotImplementedError()
@cachedList(
@ -246,7 +245,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(max_entries=100000, iterable=True)
async def get_users_in_room_with_profiles(
self, room_id: str
) -> Dict[str, ProfileInfo]:
) -> Mapping[str, ProfileInfo]:
"""Get a mapping from user ID to profile information for all users in a given room.
The profile information comes directly from this room's `m.room.member`
@ -285,7 +284,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached(max_entries=100000)
async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
"""Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
@ -357,7 +356,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached()
async def get_invited_rooms_for_local_user(
self, user_id: str
) -> List[RoomsForUser]:
) -> Sequence[RoomsForUser]:
"""Get all the rooms the *local* user is invited to.
Args:
@ -475,7 +474,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
@cached(iterable=True)
async def get_local_users_in_room(self, room_id: str) -> List[str]:
async def get_local_users_in_room(self, room_id: str) -> Sequence[str]:
"""
Retrieves a list of the current roommembers who are local to the server.
"""
@ -791,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Returns the set of users who share a room with `user_id`"""
room_ids = await self.get_rooms_for_user(user_id)
user_who_share_room = set()
user_who_share_room: Set[str] = set()
for room_id in room_ids:
user_ids = await self.get_users_in_room(room_id)
user_who_share_room.update(user_ids)
@ -953,7 +952,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
@cached(iterable=True, max_entries=10000)
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
"""Get current hosts in room based on current state."""
# First we check if we already have `get_users_in_room` in the cache, as

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Collection, Dict, List, Tuple
from typing import Collection, Dict, List, Mapping, Tuple
from unpaddedbase64 import encode_base64
@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureWorkerStore(EventsWorkerStore):
@cached()
def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]:
# This is a dummy function to allow get_event_reference_hashes
# to use its cache
raise NotImplementedError()
@ -36,7 +36,7 @@ class SignatureWorkerStore(EventsWorkerStore):
)
async def get_event_reference_hashes(
self, event_ids: Collection[str]
) -> Dict[str, Dict[str, bytes]]:
) -> Mapping[str, Mapping[str, bytes]]:
"""Get all hashes for given events.
Args:

View file

@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import Any, Dict, Iterable, List, Tuple, cast
from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast
from synapse.api.constants import AccountDataTypes
from synapse.replication.tcp.streams import AccountDataStream
@ -32,7 +32,9 @@ logger = logging.getLogger(__name__)
class TagsWorkerStore(AccountDataWorkerStore):
@cached()
async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
async def get_tags_for_user(
self, user_id: str
) -> Mapping[str, Mapping[str, JsonDict]]:
"""Get all the tags for a user.
@ -107,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_updated_tags(
self, user_id: str, stream_id: int
) -> Dict[str, Dict[str, JsonDict]]:
) -> Mapping[str, Mapping[str, JsonDict]]:
"""Get all the tags for the rooms where the tags have changed since the
given version

View file

@ -16,9 +16,9 @@ import logging
import re
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
@ -586,7 +586,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
@cached()
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},

View file

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import List, Sequence
from twisted.test.proto_helpers import MemoryReactor
@ -558,7 +558,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
def _check_invite_and_join_status(
self, user_id: str, expected_invites: int, expected_memberships: int
) -> List[RoomsForUser]:
) -> Sequence[RoomsForUser]:
"""Check invite and room membership status of a user.
Args