diff --git a/changelog.d/16612.misc b/changelog.d/16612.misc new file mode 100644 index 000000000..93ceaeafc --- /dev/null +++ b/changelog.d/16612.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index ef8590db6..75fe0183f 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -348,8 +348,7 @@ class Porter: backward_chunk = 0 already_ported = 0 else: - forward_chunk = row["forward_rowid"] - backward_chunk = row["backward_rowid"] + forward_chunk, backward_chunk = row if total_to_port is None: already_ported, total_to_port = await self._get_total_count_to_port( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6d680b079..afd8138ca 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -269,7 +269,7 @@ class RoomCreationHandler: self, requester: Requester, old_room_id: str, - old_room: Dict[str, Any], + old_room: Tuple[bool, str, bool], new_room_id: str, new_version: RoomVersion, tombstone_event: EventBase, @@ -279,7 +279,7 @@ class RoomCreationHandler: Args: requester: the user requesting the upgrade old_room_id: the id of the room to be replaced - old_room: a dict containing room information for the room to be replaced, + old_room: a tuple containing room information for the room to be replaced, as returned by `RoomWorkerStore.get_room`. new_room_id: the id of the replacement room new_version: the version to upgrade the room to @@ -299,7 +299,7 @@ class RoomCreationHandler: await self.store.store_room( room_id=new_room_id, room_creator_user_id=user_id, - is_public=old_room["is_public"], + is_public=old_room[0], room_version=new_version, ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 918eb203e..eddc2af9b 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1260,7 +1260,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Add new room to the room directory if the old room was there # Remove old room from the room directory old_room = await self.store.get_room(old_room_id) - if old_room is not None and old_room["is_public"]: + # If the old room exists and is public. + if old_room is not None and old_room[0]: await self.store.set_room_is_public(old_room_id, False) await self.store.set_room_is_public(room_id, True) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 755c59274..812144a12 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1860,7 +1860,8 @@ class PublicRoomListManager: if not room: return False - return room.get("is_public", False) + # The first item is whether the room is public. + return room[0] async def add_room_to_public_room_list(self, room_id: str) -> None: """Publishes a room to the public room list. diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 23a034522..7e40bea8a 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -413,8 +413,8 @@ class RoomMembersRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - ret = await self.store.get_room(room_id) - if not ret: + room = await self.store.get_room(room_id) + if not room: raise NotFoundError("Room not found") members = await self.store.get_users_in_room(room_id) @@ -442,8 +442,8 @@ class RoomStateRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - ret = await self.store.get_room(room_id) - if not ret: + room = await self.store.get_room(room_id) + if not room: raise NotFoundError("Room not found") event_ids = await self._storage_controllers.state.get_current_state_ids(room_id) diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py index 82944ca71..3534c3c25 100644 --- a/synapse/rest/client/directory.py +++ b/synapse/rest/client/directory.py @@ -147,7 +147,7 @@ class ClientDirectoryListServer(RestServlet): if room is None: raise NotFoundError("Unknown room") - return 200, {"visibility": "public" if room["is_public"] else "private"} + return 200, {"visibility": "public" if room[0] else "private"} class PutBody(RequestBodyModel): visibility: Literal["public", "private"] = "public" diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0af050730..eb34de4df 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1597,7 +1597,7 @@ class DatabasePool: retcols: Collection[str], allow_none: Literal[False] = False, desc: str = "simple_select_one", - ) -> Dict[str, Any]: + ) -> Tuple[Any, ...]: ... @overload @@ -1608,7 +1608,7 @@ class DatabasePool: retcols: Collection[str], allow_none: Literal[True] = True, desc: str = "simple_select_one", - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Tuple[Any, ...]]: ... async def simple_select_one( @@ -1618,7 +1618,7 @@ class DatabasePool: retcols: Collection[str], allow_none: bool = False, desc: str = "simple_select_one", - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Tuple[Any, ...]]: """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -2127,7 +2127,7 @@ class DatabasePool: keyvalues: Dict[str, Any], retcols: Collection[str], allow_none: bool = False, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Tuple[Any, ...]]: select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table) if keyvalues: @@ -2145,7 +2145,7 @@ class DatabasePool: if txn.rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - return dict(zip(retcols, row)) + return row async def simple_delete_one( self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one" diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 04d12a876..775abbac7 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -255,33 +255,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): A dict containing the device information, or `None` if the device does not exist. """ - return await self.db_pool.simple_select_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, - retcols=("user_id", "device_id", "display_name"), - desc="get_device", - allow_none=True, - ) - - async def get_device_opt( - self, user_id: str, device_id: str - ) -> Optional[Dict[str, Any]]: - """Retrieve a device. Only returns devices that are not marked as - hidden. - - Args: - user_id: The ID of the user which owns the device - device_id: The ID of the device to retrieve - Returns: - A dict containing the device information, or None if the device does not exist. - """ - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), desc="get_device", allow_none=True, ) + if row is None: + return None + return {"user_id": row[0], "device_id": row[1], "display_name": row[2]} async def get_devices_by_user( self, user_id: str @@ -1221,9 +1204,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): retcols=["device_id", "device_data"], allow_none=True, ) - return ( - (row["device_id"], json_decoder.decode(row["device_data"])) if row else None - ) + return (row[0], json_decoder.decode(row[1])) if row else None def _store_dehydrated_device_txn( self, @@ -2326,13 +2307,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): `FALSE` have not been converted. """ - row = await self.db_pool.simple_select_one( - table="device_lists_changes_converted_stream_position", - keyvalues={}, - retcols=["stream_id", "room_id"], - desc="get_device_change_last_converted_pos", + return cast( + Tuple[int, str], + await self.db_pool.simple_select_one( + table="device_lists_changes_converted_stream_position", + keyvalues={}, + retcols=["stream_id", "room_id"], + desc="get_device_change_last_converted_pos", + ), ) - return row["stream_id"], row["room_id"] async def set_device_change_last_converted_pos( self, diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index ad904a26a..fae23c340 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -506,19 +506,26 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): # it isn't there. raise StoreError(404, "No backup with that version exists") - result = self.db_pool.simple_select_one_txn( - txn, - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, - retcols=("version", "algorithm", "auth_data", "etag"), - allow_none=False, + row = cast( + Tuple[int, str, str, Optional[int]], + self.db_pool.simple_select_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={ + "user_id": user_id, + "version": this_version, + "deleted": 0, + }, + retcols=("version", "algorithm", "auth_data", "etag"), + allow_none=False, + ), ) - assert result is not None # see comment on `simple_select_one_txn` - result["auth_data"] = db_to_json(result["auth_data"]) - result["version"] = str(result["version"]) - if result["etag"] is None: - result["etag"] = 0 - return result + return { + "auth_data": db_to_json(row[2]), + "version": str(row[0]), + "algorithm": row[1], + "etag": 0 if row[3] is None else row[3], + } return await self.db_pool.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 3005e2a2c..8cb61eaee 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1266,9 +1266,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker if row is None: continue - key_id = row["key_id"] - key_json = row["key_json"] - used = row["used"] + key_id, key_json, used = row # Mark fallback key as used if not already. if not used and mark_as_used: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index f1b099150..7e992ca4a 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -193,7 +193,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Check if we have indexed the room so we can use the chain cover # algorithm. room = await self.get_room(room_id) # type: ignore[attr-defined] - if room["has_auth_chain_index"]: + # If the room has an auth chain index. + if room[1]: try: return await self.db_pool.runInteraction( "get_auth_chain_ids_chains", @@ -411,7 +412,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Check if we have indexed the room so we can use the chain cover # algorithm. room = await self.get_room(room_id) # type: ignore[attr-defined] - if room["has_auth_chain_index"]: + # If the room has an auth chain index. + if room[1]: try: return await self.db_pool.runInteraction( "get_auth_chain_difference_chains", @@ -1437,24 +1439,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) if event_lookup_result is not None: + event_type, depth, stream_ordering = event_lookup_result logger.debug( "_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s", room_id, seed_event_id, - event_lookup_result["depth"], - event_lookup_result["stream_ordering"], - event_lookup_result["type"], + depth, + stream_ordering, + event_type, ) - if event_lookup_result["depth"]: - queue.put( - ( - -event_lookup_result["depth"], - -event_lookup_result["stream_ordering"], - seed_event_id, - event_lookup_result["type"], - ) - ) + if depth: + queue.put((-depth, -stream_ordering, seed_event_id, event_type)) while not queue.empty() and len(event_id_results) < limit: try: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 7c34bde3e..5207cc0f4 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1934,8 +1934,7 @@ class PersistEventsStore: if row is None: return - redacted_relates_to = row["relates_to_id"] - rel_type = row["relation_type"] + redacted_relates_to, rel_type = row self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 5bf864c1f..4e63a16fa 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1998,7 +1998,7 @@ class EventsWorkerStore(SQLBaseStore): if not res: raise SynapseError(404, "Could not find event %s" % (event_id,)) - return int(res["topological_ordering"]), int(res["stream_ordering"]) + return int(res[0]), int(res[1]) async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: """Retrieve the entry with the lowest expiry timestamp in the event_expiry diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 7f99c64f1..3f80a64dc 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -208,7 +208,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) if row is None: return None - return LocalMedia(media_id=media_id, **row) + return LocalMedia( + media_id=media_id, + media_type=row[0], + media_length=row[1], + upload_name=row[2], + created_ts=row[3], + quarantined_by=row[4], + url_cache=row[5], + last_access_ts=row[6], + safe_from_quarantine=row[7], + ) async def get_local_media_by_user_paginate( self, @@ -541,7 +551,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) if row is None: return row - return RemoteMedia(media_origin=origin, media_id=media_id, **row) + return RemoteMedia( + media_origin=origin, + media_id=media_id, + media_type=row[0], + media_length=row[1], + upload_name=row[2], + created_ts=row[3], + filesystem_id=row[4], + last_access_ts=row[5], + quarantined_by=row[6], + ) async def store_cached_remote_media( self, @@ -665,11 +685,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): if row is None: return None return ThumbnailInfo( - width=row["thumbnail_width"], - height=row["thumbnail_height"], - method=row["thumbnail_method"], - type=row["thumbnail_type"], - length=row["thumbnail_length"], + width=row[0], height=row[1], method=row[2], type=row[3], length=row[4] ) @trace diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 3ba9cc885..7ed111f63 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import TYPE_CHECKING, Optional -from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -138,23 +137,18 @@ class ProfileWorkerStore(SQLBaseStore): return 50 async def get_profileinfo(self, user_id: UserID) -> ProfileInfo: - try: - profile = await self.db_pool.simple_select_one( - table="profiles", - keyvalues={"full_user_id": user_id.to_string()}, - retcols=("displayname", "avatar_url"), - desc="get_profileinfo", - ) - except StoreError as e: - if e.code == 404: - # no match - return ProfileInfo(None, None) - else: - raise - - return ProfileInfo( - avatar_url=profile["avatar_url"], display_name=profile["displayname"] + profile = await self.db_pool.simple_select_one( + table="profiles", + keyvalues={"full_user_id": user_id.to_string()}, + retcols=("displayname", "avatar_url"), + desc="get_profileinfo", + allow_none=True, ) + if profile is None: + # no match + return ProfileInfo(None, None) + + return ProfileInfo(avatar_url=profile[1], display_name=profile[0]) async def get_profile_displayname(self, user_id: UserID) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 37135d431..f72a23c58 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -468,8 +468,7 @@ class PushRuleStore(PushRulesWorkerStore): "before/after rule not found: %s" % (relative_to_rule,) ) - base_priority_class = res["priority_class"] - base_rule_priority = res["priority"] + base_priority_class, base_rule_priority = res if base_priority_class != priority_class: raise InconsistentRuleException( diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 56e8eb16a..3484ce9ef 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -701,8 +701,8 @@ class ReceiptsWorkerStore(SQLBaseStore): allow_none=True, ) - stream_ordering = int(res["stream_ordering"]) if res else None - rx_ts = res["received_ts"] if res else 0 + stream_ordering = int(res[0]) if res else None + rx_ts = res[1] if res else 0 # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 933d76e90..dec985857 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -425,17 +425,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): account timestamp as milliseconds since the epoch. None if the account has not been renewed using the current token yet. """ - ret_dict = await self.db_pool.simple_select_one( - table="account_validity", - keyvalues={"renewal_token": renewal_token}, - retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], - desc="get_user_from_renewal_token", - ) - - return ( - ret_dict["user_id"], - ret_dict["expiration_ts_ms"], - ret_dict["token_used_ts_ms"], + return cast( + Tuple[str, int, Optional[int]], + await self.db_pool.simple_select_one( + table="account_validity", + keyvalues={"renewal_token": renewal_token}, + retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], + desc="get_user_from_renewal_token", + ), ) async def get_renewal_token_for_user(self, user_id: str) -> str: @@ -989,16 +986,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Returns: user id, or None if no user id/threepid mapping exists """ - ret = self.db_pool.simple_select_one_txn( + return self.db_pool.simple_select_one_onecol_txn( txn, "user_threepids", {"medium": medium, "address": address}, - ["user_id"], + "user_id", True, ) - if ret: - return ret["user_id"] - return None async def user_add_threepid( self, @@ -1435,16 +1429,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): if res is None: return False + uses_allowed, pending, completed, expiry_time = res + # Check if the token has expired now = self._clock.time_msec() - if res["expiry_time"] and res["expiry_time"] < now: + if expiry_time and expiry_time < now: return False # Check if the token has been used up - if ( - res["uses_allowed"] - and res["pending"] + res["completed"] >= res["uses_allowed"] - ): + if uses_allowed and pending + completed >= uses_allowed: return False # Otherwise, the token is valid @@ -1490,8 +1483,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # Override type because the return type is only optional if # allow_none is True, and we don't want mypy throwing errors # about None not being indexable. - res = cast( - Dict[str, Any], + pending, completed = cast( + Tuple[int, int], self.db_pool.simple_select_one_txn( txn, "registration_tokens", @@ -1506,8 +1499,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "registration_tokens", keyvalues={"token": token}, updatevalues={ - "completed": res["completed"] + 1, - "pending": res["pending"] - 1, + "completed": completed + 1, + "pending": pending - 1, }, ) @@ -1585,13 +1578,22 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): Returns: A dict, or None if token doesn't exist. """ - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( "registration_tokens", keyvalues={"token": token}, retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"], allow_none=True, desc="get_one_registration_token", ) + if row is None: + return None + return { + "token": row[0], + "uses_allowed": row[1], + "pending": row[2], + "completed": row[3], + "expiry_time": row[4], + } async def generate_registration_token( self, length: int, chars: str @@ -1714,7 +1716,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return None # Get all info about the token so it can be sent in the response - return self.db_pool.simple_select_one_txn( + result = self.db_pool.simple_select_one_txn( txn, "registration_tokens", keyvalues={"token": token}, @@ -1728,6 +1730,17 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): allow_none=True, ) + if result is None: + return result + + return { + "token": result[0], + "uses_allowed": result[1], + "pending": result[2], + "completed": result[3], + "expiry_time": result[4], + } + return await self.db_pool.runInteraction( "update_registration_token", _update_registration_token_txn ) @@ -1939,11 +1952,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): keyvalues={"token": token}, updatevalues={"used_ts": ts}, ) - user_id = values["user_id"] - expiry_ts = values["expiry_ts"] - used_ts = values["used_ts"] - auth_provider_id = values["auth_provider_id"] - auth_provider_session_id = values["auth_provider_session_id"] + ( + user_id, + expiry_ts, + used_ts, + auth_provider_id, + auth_provider_session_id, + ) = values # Token was already used if used_ts is not None: @@ -2756,12 +2771,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): # reason, the next check is on the client secret, which is NOT NULL, # so we don't have to worry about the client secret matching by # accident. - row = {"client_secret": None, "validated_at": None} + row = None, None else: raise ThreepidValidationError("Unknown session_id") - retrieved_client_secret = row["client_secret"] - validated_at = row["validated_at"] + retrieved_client_secret, validated_at = row row = self.db_pool.simple_select_one_txn( txn, @@ -2775,8 +2789,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): raise ThreepidValidationError( "Validation token not found or has expired" ) - expires = row["expires"] - next_link = row["next_link"] + expires, next_link = row if retrieved_client_secret != client_secret: raise ThreepidValidationError( diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index afb880532..ef26d5d9d 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -213,21 +213,31 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]: + async def get_room(self, room_id: str) -> Optional[Tuple[bool, bool]]: """Retrieve a room. Args: room_id: The ID of the room to retrieve. Returns: - A dict containing the room information, or None if the room is unknown. + A tuple containing the room information: + * True if the room is public + * True if the room has an auth chain index + + or None if the room is unknown. """ - return await self.db_pool.simple_select_one( - table="rooms", - keyvalues={"room_id": room_id}, - retcols=("room_id", "is_public", "creator", "has_auth_chain_index"), - desc="get_room", - allow_none=True, + row = cast( + Optional[Tuple[Optional[Union[int, bool]], Optional[Union[int, bool]]]], + await self.db_pool.simple_select_one( + table="rooms", + keyvalues={"room_id": room_id}, + retcols=("is_public", "has_auth_chain_index"), + desc="get_room", + allow_none=True, + ), ) + if row is None: + return row + return bool(row[0]), bool(row[1]) async def get_room_with_stats(self, room_id: str) -> Optional[RoomStats]: """Retrieve room with statistics. @@ -794,10 +804,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ) if row: - return RatelimitOverride( - messages_per_second=row["messages_per_second"], - burst_count=row["burst_count"], - ) + return RatelimitOverride(messages_per_second=row[0], burst_count=row[1]) else: return None @@ -1371,13 +1378,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): join. """ - result = await self.db_pool.simple_select_one( - table="partial_state_rooms", - keyvalues={"room_id": room_id}, - retcols=("join_event_id", "device_lists_stream_id"), - desc="get_join_event_id_for_partial_state", + return cast( + Tuple[str, int], + await self.db_pool.simple_select_one( + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + retcols=("join_event_id", "device_lists_stream_id"), + desc="get_join_event_id_for_partial_state", + ), ) - return result["join_event_id"], result["device_lists_stream_id"] def get_un_partial_stated_rooms_token(self, instance_name: str) -> int: return self._un_partial_stated_rooms_stream_id_gen.get_current_token_for_writer( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 1ed7f2d0e..60d4a9ef3 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -559,17 +559,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): "non-local user %s" % (user_id,), ) - results_dict = await self.db_pool.simple_select_one( - "local_current_membership", - {"room_id": room_id, "user_id": user_id}, - ("membership", "event_id"), - allow_none=True, - desc="get_local_current_membership_for_user_in_room", + results = cast( + Optional[Tuple[str, str]], + await self.db_pool.simple_select_one( + "local_current_membership", + {"room_id": room_id, "user_id": user_id}, + ("membership", "event_id"), + allow_none=True, + desc="get_local_current_membership_for_user_in_room", + ), ) - if not results_dict: + if not results: return None, None - return results_dict.get("membership"), results_dict.get("event_id") + return results @cached(max_entries=500000, iterable=True) async def get_rooms_for_user_with_stream_ordering( diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 2225f8272..563c275a2 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1014,9 +1014,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): desc="get_position_for_event", ) - return PersistedEventPosition( - row["instance_name"] or "master", row["stream_ordering"] - ) + return PersistedEventPosition(row[1] or "master", row[0]) async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken: """The stream token for an event @@ -1033,9 +1031,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", ) - return RoomStreamToken( - topological=row["topological_ordering"], stream=row["stream_ordering"] - ) + return RoomStreamToken(topological=row[1], stream=row[0]) async def get_current_topological_token(self, room_id: str, stream_key: int) -> int: """Gets the topological token in a room after or at the given stream @@ -1180,26 +1176,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - results = self.db_pool.simple_select_one_txn( - txn, - "events", - keyvalues={"event_id": event_id, "room_id": room_id}, - retcols=["stream_ordering", "topological_ordering"], + stream_ordering, topological_ordering = cast( + Tuple[int, int], + self.db_pool.simple_select_one_txn( + txn, + "events", + keyvalues={"event_id": event_id, "room_id": room_id}, + retcols=["stream_ordering", "topological_ordering"], + ), ) - # This cannot happen as `allow_none=False`. - assert results is not None - # Paginating backwards includes the event at the token, but paginating # forward doesn't. before_token = RoomStreamToken( - topological=results["topological_ordering"] - 1, - stream=results["stream_ordering"], + topological=topological_ordering - 1, stream=stream_ordering ) after_token = RoomStreamToken( - topological=results["topological_ordering"], - stream=results["stream_ordering"], + topological=topological_ordering, stream=stream_ordering ) rows, start_token = self._paginate_room_events_txn( diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py index 5555b5357..64543b4d6 100644 --- a/synapse/storage/databases/main/task_scheduler.py +++ b/synapse/storage/databases/main/task_scheduler.py @@ -183,39 +183,27 @@ class TaskSchedulerWorkerStore(SQLBaseStore): Returns: the task if available, `None` otherwise """ - row = await self.db_pool.simple_select_one( - table="scheduled_tasks", - keyvalues={"id": id}, - retcols=( - "id", - "action", - "status", - "timestamp", - "resource_id", - "params", - "result", - "error", + row = cast( + Optional[ScheduledTaskRow], + await self.db_pool.simple_select_one( + table="scheduled_tasks", + keyvalues={"id": id}, + retcols=( + "id", + "action", + "status", + "timestamp", + "resource_id", + "params", + "result", + "error", + ), + allow_none=True, + desc="get_scheduled_task", ), - allow_none=True, - desc="get_scheduled_task", ) - return ( - TaskSchedulerWorkerStore._convert_row_to_task( - ( - row["id"], - row["action"], - row["status"], - row["timestamp"], - row["resource_id"], - row["params"], - row["result"], - row["error"], - ) - ) - if row - else None - ) + return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None async def delete_scheduled_task(self, id: str) -> None: """Delete a specific task from its id. diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index fecddb414..2d341affa 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -118,19 +118,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): txn, table="received_transactions", keyvalues={"transaction_id": transaction_id, "origin": origin}, - retcols=( - "transaction_id", - "origin", - "ts", - "response_code", - "response_json", - "has_been_referenced", - ), + retcols=("response_code", "response_json"), allow_none=True, ) - if result and result["response_code"]: - return result["response_code"], db_to_json(result["response_json"]) + # If the result exists and the response code is non-0. + if result and result[0]: + return result[0], db_to_json(result[1]) else: return None @@ -200,8 +194,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): # check we have a row and retry_last_ts is not null or zero # (retry_last_ts can't be negative) - if result and result["retry_last_ts"]: - return DestinationRetryTimings(**result) + if result and result[1]: + return DestinationRetryTimings( + failure_ts=result[0], retry_last_ts=result[1], retry_interval=result[2] + ) else: return None diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 8ab7c42c4..5b164fed8 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -122,9 +122,13 @@ class UIAuthWorkerStore(SQLBaseStore): desc="get_ui_auth_session", ) - result["clientdict"] = db_to_json(result["clientdict"]) - - return UIAuthSessionData(session_id, **result) + return UIAuthSessionData( + session_id, + clientdict=db_to_json(result[0]), + uri=result[1], + method=result[2], + description=result[3], + ) async def mark_ui_auth_stage_complete( self, @@ -231,18 +235,15 @@ class UIAuthWorkerStore(SQLBaseStore): self, txn: LoggingTransaction, session_id: str, key: str, value: Any ) -> None: # Get the current value. - result = cast( - Dict[str, Any], - self.db_pool.simple_select_one_txn( - txn, - table="ui_auth_sessions", - keyvalues={"session_id": session_id}, - retcols=("serverdict",), - ), + result = self.db_pool.simple_select_one_onecol_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcol="serverdict", ) # Update it and add it back to the database. - serverdict = db_to_json(result["serverdict"]) + serverdict = db_to_json(result) serverdict[key] = value self.db_pool.simple_update_one_txn( @@ -265,14 +266,14 @@ class UIAuthWorkerStore(SQLBaseStore): Raises: StoreError if the session cannot be found. """ - result = await self.db_pool.simple_select_one( + result = await self.db_pool.simple_select_one_onecol( table="ui_auth_sessions", keyvalues={"session_id": session_id}, - retcols=("serverdict",), + retcol="serverdict", desc="get_ui_auth_session_data", ) - serverdict = db_to_json(result["serverdict"]) + serverdict = db_to_json(result) return serverdict.get(key, default) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index a9f5d68b6..1a38f3d78 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -20,7 +20,6 @@ from typing import ( Collection, Iterable, List, - Mapping, Optional, Sequence, Set, @@ -833,13 +832,25 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) - async def _get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]: - return await self.db_pool.simple_select_one( - table="user_directory", - keyvalues={"user_id": user_id}, - retcols=("display_name", "avatar_url"), - allow_none=True, - desc="get_user_in_directory", + async def _get_user_in_directory( + self, user_id: str + ) -> Optional[Tuple[Optional[str], Optional[str]]]: + """ + Fetch the user information in the user directory. + + Returns: + None if the user is unknown, otherwise a tuple of display name and + avatar URL (both of which may be None). + """ + return cast( + Optional[Tuple[Optional[str], Optional[str]]], + await self.db_pool.simple_select_one( + table="user_directory", + keyvalues={"user_id": user_id}, + retcols=("display_name", "avatar_url"), + allow_none=True, + desc="get_user_in_directory", + ), ) async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None: diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 76c56d543..15e19b15f 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -84,7 +84,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) - return self.get_success( + row = self.get_success( self.store.db_pool.simple_select_one( table + "_current", {id_col: stat_id}, @@ -93,6 +93,8 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) + return None if row is None else dict(zip(cols, row)) + def _perform_background_initial_update(self) -> None: # Do the initial population of the stats via the background update self._add_background_updates() diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index b5f15aa7d..388447eea 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -366,7 +366,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) profile = self.get_success(self.store._get_user_in_directory(regular_user_id)) assert profile is not None - self.assertTrue(profile["display_name"] == display_name) + self.assertTrue(profile[0] == display_name) def test_handle_local_profile_change_with_deactivated_user(self) -> None: # create user @@ -385,7 +385,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # profile is in directory profile = self.get_success(self.store._get_user_in_directory(r_user_id)) assert profile is not None - self.assertTrue(profile["display_name"] == display_name) + self.assertEqual(profile[0], display_name) # deactivate user self.get_success(self.store.set_user_deactivated_status(r_user_id, True)) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 37f37a09d..42b065d88 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2706,7 +2706,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # is in user directory profile = self.get_success(self.store._get_user_in_directory(self.other_user)) assert profile is not None - self.assertTrue(profile["display_name"] == "User") + self.assertEqual(profile[0], "User") # Deactivate user channel = self.make_request( diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index cffbda9a7..bd59bb50c 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -139,12 +139,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # # Note that we don't have the UI Auth session ID, so just pull out the single # row. - ui_auth_data = self.get_success( - self.store.db_pool.simple_select_one( - "ui_auth_sessions", keyvalues={}, retcols=("clientdict",) + result = self.get_success( + self.store.db_pool.simple_select_one_onecol( + "ui_auth_sessions", keyvalues={}, retcol="clientdict" ) ) - client_dict = db_to_json(ui_auth_data["clientdict"]) + client_dict = db_to_json(result) self.assertNotIn("new_password", client_dict) @override_config({"rc_3pid_validation": {"burst_count": 3}}) diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index ba4e017a0..b04094b7b 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -270,15 +270,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertLessEqual(det_data.items(), channel.json_body.items()) # Check the `completed` counter has been incremented and pending is 0 - res = self.get_success( + pending, completed = self.get_success( store.db_pool.simple_select_one( "registration_tokens", keyvalues={"token": token}, retcols=["pending", "completed"], ) ) - self.assertEqual(res["completed"], 1) - self.assertEqual(res["pending"], 0) + self.assertEqual(completed, 1) + self.assertEqual(pending, 0) @override_config({"registration_requires_token": True}) def test_POST_registration_token_invalid(self) -> None: @@ -372,15 +372,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): params1["auth"]["type"] = LoginType.DUMMY self.make_request(b"POST", self.url, params1) # Check pending=0 and completed=1 - res = self.get_success( + pending, completed = self.get_success( store.db_pool.simple_select_one( "registration_tokens", keyvalues={"token": token}, retcols=["pending", "completed"], ) ) - self.assertEqual(res["pending"], 0) - self.assertEqual(res["completed"], 1) + self.assertEqual(pending, 0) + self.assertEqual(completed, 1) # Check auth still fails when using token with session2 channel = self.make_request(b"POST", self.url, params2) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index f34b6b2dc..491e6d5e6 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -222,7 +222,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret) + self.assertEqual((1, 2, 3), ret) self.mock_txn.execute.assert_called_once_with( "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -243,7 +243,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertFalse(ret) + self.assertIsNone(ret) @defer.inlineCallbacks def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]: diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index ce34195a2..d3ffe963d 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -42,16 +42,9 @@ class RoomStoreTestCase(HomeserverTestCase): ) def test_get_room(self) -> None: - res = self.get_success(self.store.get_room(self.room.to_string())) - assert res is not None - self.assertLessEqual( - { - "room_id": self.room.to_string(), - "creator": self.u_creator.to_string(), - "is_public": True, - }.items(), - res.items(), - ) + room = self.get_success(self.store.get_room(self.room.to_string())) + assert room is not None + self.assertTrue(room[0]) def test_get_room_unknown_room(self) -> None: self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))