0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-05-18 11:33:45 +02:00

Bump black from 23.10.1 to 24.2.0 (#16936)

This commit is contained in:
dependabot[bot] 2024-03-13 16:46:44 +00:00 committed by GitHub
parent 2bdf6280f6
commit 1e68b56a62
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
74 changed files with 407 additions and 509 deletions

44
poetry.lock generated
View file

@ -169,29 +169,33 @@ lxml = ["lxml"]
[[package]] [[package]]
name = "black" name = "black"
version = "23.10.1" version = "24.2.0"
description = "The uncompromising code formatter." description = "The uncompromising code formatter."
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "black-23.10.1-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:ec3f8e6234c4e46ff9e16d9ae96f4ef69fa328bb4ad08198c8cee45bb1f08c69"}, {file = "black-24.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6981eae48b3b33399c8757036c7f5d48a535b962a7c2310d19361edeef64ce29"},
{file = "black-23.10.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:1b917a2aa020ca600483a7b340c165970b26e9029067f019e3755b56e8dd5916"}, {file = "black-24.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d533d5e3259720fdbc1b37444491b024003e012c5173f7d06825a77508085430"},
{file = "black-23.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c74de4c77b849e6359c6f01987e94873c707098322b91490d24296f66d067dc"}, {file = "black-24.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61a0391772490ddfb8a693c067df1ef5227257e72b0e4108482b8d41b5aee13f"},
{file = "black-23.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:7b4d10b0f016616a0d93d24a448100adf1699712fb7a4efd0e2c32bbb219b173"}, {file = "black-24.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:992e451b04667116680cb88f63449267c13e1ad134f30087dec8527242e9862a"},
{file = "black-23.10.1-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b15b75fc53a2fbcac8a87d3e20f69874d161beef13954747e053bca7a1ce53a0"}, {file = "black-24.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:163baf4ef40e6897a2a9b83890e59141cc8c2a98f2dda5080dc15c00ee1e62cd"},
{file = "black-23.10.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:e293e4c2f4a992b980032bbd62df07c1bcff82d6964d6c9496f2cd726e246ace"}, {file = "black-24.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e37c99f89929af50ffaf912454b3e3b47fd64109659026b678c091a4cd450fb2"},
{file = "black-23.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d56124b7a61d092cb52cce34182a5280e160e6aff3137172a68c2c2c4b76bcb"}, {file = "black-24.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9de21bafcba9683853f6c96c2d515e364aee631b178eaa5145fc1c61a3cc92"},
{file = "black-23.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:3f157a8945a7b2d424da3335f7ace89c14a3b0625e6593d21139c2d8214d55ce"}, {file = "black-24.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:9db528bccb9e8e20c08e716b3b09c6bdd64da0dd129b11e160bf082d4642ac23"},
{file = "black-23.10.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:cfcce6f0a384d0da692119f2d72d79ed07c7159879d0bb1bb32d2e443382bf3a"}, {file = "black-24.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d84f29eb3ee44859052073b7636533ec995bd0f64e2fb43aeceefc70090e752b"},
{file = "black-23.10.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:33d40f5b06be80c1bbce17b173cda17994fbad096ce60eb22054da021bf933d1"}, {file = "black-24.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e08fb9a15c914b81dd734ddd7fb10513016e5ce7e6704bdd5e1251ceee51ac9"},
{file = "black-23.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:840015166dbdfbc47992871325799fd2dc0dcf9395e401ada6d88fe11498abad"}, {file = "black-24.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:810d445ae6069ce64030c78ff6127cd9cd178a9ac3361435708b907d8a04c693"},
{file = "black-23.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:037e9b4664cafda5f025a1728c50a9e9aedb99a759c89f760bd83730e76ba884"}, {file = "black-24.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:ba15742a13de85e9b8f3239c8f807723991fbfae24bad92d34a2b12e81904982"},
{file = "black-23.10.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:7cb5936e686e782fddb1c73f8aa6f459e1ad38a6a7b0e54b403f1f05a1507ee9"}, {file = "black-24.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7e53a8c630f71db01b28cd9602a1ada68c937cbf2c333e6ed041390d6968faf4"},
{file = "black-23.10.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:7670242e90dc129c539e9ca17665e39a146a761e681805c54fbd86015c7c84f7"}, {file = "black-24.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:93601c2deb321b4bad8f95df408e3fb3943d85012dddb6121336b8e24a0d1218"},
{file = "black-23.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed45ac9a613fb52dad3b61c8dea2ec9510bf3108d4db88422bacc7d1ba1243d"}, {file = "black-24.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0057f800de6acc4407fe75bb147b0c2b5cbb7c3ed110d3e5999cd01184d53b0"},
{file = "black-23.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:6d23d7822140e3fef190734216cefb262521789367fbdc0b3f22af6744058982"}, {file = "black-24.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:faf2ee02e6612577ba0181f4347bcbcf591eb122f7841ae5ba233d12c39dcb4d"},
{file = "black-23.10.1-py3-none-any.whl", hash = "sha256:d431e6739f727bb2e0495df64a6c7a5310758e87505f5f8cde9ff6c0f2d7e4fe"}, {file = "black-24.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:057c3dc602eaa6fdc451069bd027a1b2635028b575a6c3acfd63193ced20d9c8"},
{file = "black-23.10.1.tar.gz", hash = "sha256:1f8ce316753428ff68749c65a5f7844631aa18c8679dfd3ca9dc1a289979c258"}, {file = "black-24.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:08654d0797e65f2423f850fc8e16a0ce50925f9337fb4a4a176a7aa4026e63f8"},
{file = "black-24.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca610d29415ee1a30a3f30fab7a8f4144e9d34c89a235d81292a1edb2b55f540"},
{file = "black-24.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:4dd76e9468d5536abd40ffbc7a247f83b2324f0c050556d9c371c2b9a9a95e31"},
{file = "black-24.2.0-py3-none-any.whl", hash = "sha256:e8a6ae970537e67830776488bca52000eaa37fa63b9988e8c487458d9cd5ace6"},
{file = "black-24.2.0.tar.gz", hash = "sha256:bce4f25c27c3435e4dace4815bcb2008b87e167e3bf4ee47ccdc5ce906eb4894"},
] ]
[package.dependencies] [package.dependencies]
@ -205,7 +209,7 @@ typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""}
[package.extras] [package.extras]
colorama = ["colorama (>=0.4.3)"] colorama = ["colorama (>=0.4.3)"]
d = ["aiohttp (>=3.7.4)"] d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"] uvloop = ["uvloop (>=0.15.2)"]

View file

@ -1040,10 +1040,10 @@ class Porter:
return done, remaining + done return done, remaining + done
async def _setup_state_group_id_seq(self) -> None: async def _setup_state_group_id_seq(self) -> None:
curr_id: Optional[ curr_id: Optional[int] = (
int await self.sqlite_store.db_pool.simple_select_one_onecol(
] = await self.sqlite_store.db_pool.simple_select_one_onecol( table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True )
) )
if not curr_id: if not curr_id:
@ -1132,13 +1132,13 @@ class Porter:
) )
async def _setup_auth_chain_sequence(self) -> None: async def _setup_auth_chain_sequence(self) -> None:
curr_chain_id: Optional[ curr_chain_id: Optional[int] = (
int await self.sqlite_store.db_pool.simple_select_one_onecol(
] = await self.sqlite_store.db_pool.simple_select_one_onecol( table="event_auth_chains",
table="event_auth_chains", keyvalues={},
keyvalues={}, retcol="MAX(chain_id)",
retcol="MAX(chain_id)", allow_none=True,
allow_none=True, )
) )
def r(txn: LoggingTransaction) -> None: def r(txn: LoggingTransaction) -> None:

View file

@ -43,7 +43,6 @@ MAIN_TIMELINE: Final = "main"
class Membership: class Membership:
"""Represents the membership states of a user in a room.""" """Represents the membership states of a user in a room."""
INVITE: Final = "invite" INVITE: Final = "invite"

View file

@ -370,9 +370,11 @@ class RoomVersionCapability:
MSC3244_CAPABILITIES = { MSC3244_CAPABILITIES = {
cap.identifier: { cap.identifier: {
"preferred": cap.preferred_version.identifier "preferred": (
if cap.preferred_version is not None cap.preferred_version.identifier
else None, if cap.preferred_version is not None
else None
),
"support": [ "support": [
v.identifier v.identifier
for v in KNOWN_ROOM_VERSIONS.values() for v in KNOWN_ROOM_VERSIONS.values()

View file

@ -188,9 +188,9 @@ class SynapseHomeServer(HomeServer):
PasswordResetSubmitTokenResource, PasswordResetSubmitTokenResource,
) )
resources[ resources["/_synapse/client/password_reset/email/submit_token"] = (
"/_synapse/client/password_reset/email/submit_token" PasswordResetSubmitTokenResource(self)
] = PasswordResetSubmitTokenResource(self) )
if name == "consent": if name == "consent":
from synapse.rest.consent.consent_resource import ConsentResource from synapse.rest.consent.consent_resource import ConsentResource

View file

@ -362,16 +362,16 @@ class ApplicationServiceApi(SimpleHttpClient):
# TODO: Update to stable prefixes once MSC3202 completes FCP merge # TODO: Update to stable prefixes once MSC3202 completes FCP merge
if service.msc3202_transaction_extensions: if service.msc3202_transaction_extensions:
if one_time_keys_count: if one_time_keys_count:
body[ body["org.matrix.msc3202.device_one_time_key_counts"] = (
"org.matrix.msc3202.device_one_time_key_counts" one_time_keys_count
] = one_time_keys_count )
body[ body["org.matrix.msc3202.device_one_time_keys_count"] = (
"org.matrix.msc3202.device_one_time_keys_count" one_time_keys_count
] = one_time_keys_count )
if unused_fallback_keys: if unused_fallback_keys:
body[ body["org.matrix.msc3202.device_unused_fallback_key_types"] = (
"org.matrix.msc3202.device_unused_fallback_key_types" unused_fallback_keys
] = unused_fallback_keys )
if device_list_summary: if device_list_summary:
body["org.matrix.msc3202.device_lists"] = { body["org.matrix.msc3202.device_lists"] = {
"changed": list(device_list_summary.changed), "changed": list(device_list_summary.changed),

View file

@ -171,9 +171,9 @@ class RegistrationConfig(Config):
refreshable_access_token_lifetime = self.parse_duration( refreshable_access_token_lifetime = self.parse_duration(
refreshable_access_token_lifetime refreshable_access_token_lifetime
) )
self.refreshable_access_token_lifetime: Optional[ self.refreshable_access_token_lifetime: Optional[int] = (
int refreshable_access_token_lifetime
] = refreshable_access_token_lifetime )
if ( if (
self.session_lifetime is not None self.session_lifetime is not None

View file

@ -199,9 +199,9 @@ class ContentRepositoryConfig(Config):
provider_config["module"] == "file_system" provider_config["module"] == "file_system"
or provider_config["module"] == "synapse.rest.media.v1.storage_provider" or provider_config["module"] == "synapse.rest.media.v1.storage_provider"
): ):
provider_config[ provider_config["module"] = (
"module" "synapse.media.storage_provider.FileStorageProviderBackend"
] = "synapse.media.storage_provider.FileStorageProviderBackend" )
provider_class, parsed_config = load_module( provider_class, parsed_config = load_module(
provider_config, ("media_storage_providers", "<item %i>" % i) provider_config, ("media_storage_providers", "<item %i>" % i)

View file

@ -88,8 +88,7 @@ class _EventSourceStore(Protocol):
redact_behaviour: EventRedactBehaviour, redact_behaviour: EventRedactBehaviour,
get_prev_content: bool = False, get_prev_content: bool = False,
allow_rejected: bool = False, allow_rejected: bool = False,
) -> Dict[str, "EventBase"]: ) -> Dict[str, "EventBase"]: ...
...
def validate_event_for_room_version(event: "EventBase") -> None: def validate_event_for_room_version(event: "EventBase") -> None:

View file

@ -93,16 +93,14 @@ class DictProperty(Generic[T]):
self, self,
instance: Literal[None], instance: Literal[None],
owner: Optional[Type[_DictPropertyInstance]] = None, owner: Optional[Type[_DictPropertyInstance]] = None,
) -> "DictProperty": ) -> "DictProperty": ...
...
@overload @overload
def __get__( def __get__(
self, self,
instance: _DictPropertyInstance, instance: _DictPropertyInstance,
owner: Optional[Type[_DictPropertyInstance]] = None, owner: Optional[Type[_DictPropertyInstance]] = None,
) -> T: ) -> T: ...
...
def __get__( def __get__(
self, self,
@ -161,16 +159,14 @@ class DefaultDictProperty(DictProperty, Generic[T]):
self, self,
instance: Literal[None], instance: Literal[None],
owner: Optional[Type[_DictPropertyInstance]] = None, owner: Optional[Type[_DictPropertyInstance]] = None,
) -> "DefaultDictProperty": ) -> "DefaultDictProperty": ...
...
@overload @overload
def __get__( def __get__(
self, self,
instance: _DictPropertyInstance, instance: _DictPropertyInstance,
owner: Optional[Type[_DictPropertyInstance]] = None, owner: Optional[Type[_DictPropertyInstance]] = None,
) -> T: ) -> T: ...
...
def __get__( def __get__(
self, self,

View file

@ -612,9 +612,9 @@ class EventClientSerializer:
serialized_aggregations = {} serialized_aggregations = {}
if event_aggregations.references: if event_aggregations.references:
serialized_aggregations[ serialized_aggregations[RelationTypes.REFERENCE] = (
RelationTypes.REFERENCE event_aggregations.references
] = event_aggregations.references )
if event_aggregations.replace: if event_aggregations.replace:
# Include information about it in the relations dict. # Include information about it in the relations dict.

View file

@ -169,9 +169,9 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often # We cache responses to state queries, as they take a while and often
# come in waves. # come in waves.
self._state_resp_cache: ResponseCache[ self._state_resp_cache: ResponseCache[Tuple[str, Optional[str]]] = (
Tuple[str, Optional[str]] ResponseCache(hs.get_clock(), "state_resp", timeout_ms=30000)
] = ResponseCache(hs.get_clock(), "state_resp", timeout_ms=30000) )
self._state_ids_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache( self._state_ids_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
hs.get_clock(), "state_ids_resp", timeout_ms=30000 hs.get_clock(), "state_ids_resp", timeout_ms=30000
) )

View file

@ -88,9 +88,9 @@ class FederationRemoteSendQueue(AbstractFederationSender):
# Stores the destinations we need to explicitly send presence to about a # Stores the destinations we need to explicitly send presence to about a
# given user. # given user.
# Stream position -> (user_id, destinations) # Stream position -> (user_id, destinations)
self.presence_destinations: SortedDict[ self.presence_destinations: SortedDict[int, Tuple[str, Iterable[str]]] = (
int, Tuple[str, Iterable[str]] SortedDict()
] = SortedDict() )
# (destination, key) -> EDU # (destination, key) -> EDU
self.keyed_edu: Dict[Tuple[str, tuple], Edu] = {} self.keyed_edu: Dict[Tuple[str, tuple], Edu] = {}

View file

@ -118,10 +118,10 @@ class AccountHandler:
} }
if self._use_account_validity_in_account_status: if self._use_account_validity_in_account_status:
status[ status["org.matrix.expired"] = (
"org.matrix.expired" await self._account_validity_handler.is_user_expired(
] = await self._account_validity_handler.is_user_expired( user_id.to_string()
user_id.to_string() )
) )
return status return status

View file

@ -265,9 +265,9 @@ class DirectoryHandler:
async def get_association(self, room_alias: RoomAlias) -> JsonDict: async def get_association(self, room_alias: RoomAlias) -> JsonDict:
room_id = None room_id = None
if self.hs.is_mine(room_alias): if self.hs.is_mine(room_alias):
result: Optional[ result: Optional[RoomAliasMapping] = (
RoomAliasMapping await self.get_association_from_room_alias(room_alias)
] = await self.get_association_from_room_alias(room_alias) )
if result: if result:
room_id = result.room_id room_id = result.room_id

View file

@ -1001,11 +1001,11 @@ class FederationHandler:
) )
if include_auth_user_id: if include_auth_user_id:
event_content[ event_content[EventContentFields.AUTHORISING_USER] = (
EventContentFields.AUTHORISING_USER await self._event_auth_handler.get_user_which_could_invite(
] = await self._event_auth_handler.get_user_which_could_invite( room_id,
room_id, state_ids,
state_ids, )
) )
builder = self.event_builder_factory.for_room_version( builder = self.event_builder_factory.for_room_version(

View file

@ -1367,9 +1367,9 @@ class FederationEventHandler:
) )
if remote_event.is_state() and remote_event.rejected_reason is None: if remote_event.is_state() and remote_event.rejected_reason is None:
state_map[ state_map[(remote_event.type, remote_event.state_key)] = (
(remote_event.type, remote_event.state_key) remote_event.event_id
] = remote_event.event_id )
return state_map return state_map

View file

@ -1654,9 +1654,9 @@ class EventCreationHandler:
expiry_ms=60 * 60 * 1000, expiry_ms=60 * 60 * 1000,
) )
self._external_cache_joined_hosts_updates[ self._external_cache_joined_hosts_updates[state_entry.state_group] = (
state_entry.state_group None
] = None )
async def _validate_canonical_alias( async def _validate_canonical_alias(
self, self,

View file

@ -493,9 +493,9 @@ class WorkerPresenceHandler(BasePresenceHandler):
# The number of ongoing syncs on this process, by (user ID, device ID). # The number of ongoing syncs on this process, by (user ID, device ID).
# Empty if _presence_enabled is false. # Empty if _presence_enabled is false.
self._user_device_to_num_current_syncs: Dict[ self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = (
Tuple[str, Optional[str]], int {}
] = {} )
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.instance_id = hs.get_instance_id() self.instance_id = hs.get_instance_id()
@ -818,9 +818,9 @@ class PresenceHandler(BasePresenceHandler):
# Keeps track of the number of *ongoing* syncs on this process. While # Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline. # this is non zero a user will never go offline.
self._user_device_to_num_current_syncs: Dict[ self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = (
Tuple[str, Optional[str]], int {}
] = {} )
# Keeps track of the number of *ongoing* syncs on other processes. # Keeps track of the number of *ongoing* syncs on other processes.
# #

View file

@ -320,9 +320,9 @@ class ProfileHandler:
server_name = host server_name = host
if self._is_mine_server_name(server_name): if self._is_mine_server_name(server_name):
media_info: Optional[ media_info: Optional[Union[LocalMedia, RemoteMedia]] = (
Union[LocalMedia, RemoteMedia] await self.store.get_local_media(media_id)
] = await self.store.get_local_media(media_id) )
else: else:
media_info = await self.store.get_cached_remote_media(server_name, media_id) media_info = await self.store.get_cached_remote_media(server_name, media_id)

View file

@ -188,13 +188,13 @@ class RelationsHandler:
if include_original_event: if include_original_event:
# Do not bundle aggregations when retrieving the original event because # Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it. # we want the content before relations are applied to it.
return_value[ return_value["original_event"] = (
"original_event" await self._event_serializer.serialize_event(
] = await self._event_serializer.serialize_event( event,
event, now,
now, bundle_aggregations=None,
bundle_aggregations=None, config=serialize_options,
config=serialize_options, )
) )
if next_token: if next_token:

View file

@ -538,10 +538,10 @@ class RoomCreationHandler:
# deep-copy the power-levels event before we start modifying it # deep-copy the power-levels event before we start modifying it
# note that if frozen_dicts are enabled, `power_levels` will be a frozen # note that if frozen_dicts are enabled, `power_levels` will be a frozen
# dict so we can't just copy.deepcopy it. # dict so we can't just copy.deepcopy it.
initial_state[ initial_state[(EventTypes.PowerLevels, "")] = power_levels = (
(EventTypes.PowerLevels, "") copy_and_fixup_power_levels_contents(
] = power_levels = copy_and_fixup_power_levels_contents( initial_state[(EventTypes.PowerLevels, "")]
initial_state[(EventTypes.PowerLevels, "")] )
) )
# Resolve the minimum power level required to send any state event # Resolve the minimum power level required to send any state event
@ -1362,9 +1362,11 @@ class RoomCreationHandler:
visibility = room_config.get("visibility", "private") visibility = room_config.get("visibility", "private")
preset_name = room_config.get( preset_name = room_config.get(
"preset", "preset",
RoomCreationPreset.PRIVATE_CHAT (
if visibility == "private" RoomCreationPreset.PRIVATE_CHAT
else RoomCreationPreset.PUBLIC_CHAT, if visibility == "private"
else RoomCreationPreset.PUBLIC_CHAT
),
) )
try: try:
preset_config = self._presets_dict[preset_name] preset_config = self._presets_dict[preset_name]

View file

@ -1236,11 +1236,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# If this is going to be a local join, additional information must # If this is going to be a local join, additional information must
# be included in the event content in order to efficiently validate # be included in the event content in order to efficiently validate
# the event. # the event.
content[ content[EventContentFields.AUTHORISING_USER] = (
EventContentFields.AUTHORISING_USER await self.event_auth_handler.get_user_which_could_invite(
] = await self.event_auth_handler.get_user_which_could_invite( room_id,
room_id, state_before_join,
state_before_join, )
) )
return False, [] return False, []

View file

@ -1333,9 +1333,9 @@ class SyncHandler:
and auth_event.state_key == member and auth_event.state_key == member
): ):
missing_members.discard(member) missing_members.discard(member)
additional_state_ids[ additional_state_ids[(EventTypes.Member, member)] = (
(EventTypes.Member, member) auth_event.event_id
] = auth_event.event_id )
break break
if missing_members: if missing_members:

View file

@ -931,8 +931,7 @@ class MatrixFederationHttpClient:
try_trailing_slash_on_400: bool = False, try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None, parser: Literal[None] = None,
backoff_on_all_error_codes: bool = False, backoff_on_all_error_codes: bool = False,
) -> JsonDict: ) -> JsonDict: ...
...
@overload @overload
async def put_json( async def put_json(
@ -949,8 +948,7 @@ class MatrixFederationHttpClient:
try_trailing_slash_on_400: bool = False, try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser[T]] = None, parser: Optional[ByteParser[T]] = None,
backoff_on_all_error_codes: bool = False, backoff_on_all_error_codes: bool = False,
) -> T: ) -> T: ...
...
async def put_json( async def put_json(
self, self,
@ -1140,8 +1138,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False, ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False, try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None, parser: Literal[None] = None,
) -> JsonDict: ) -> JsonDict: ...
...
@overload @overload
async def get_json( async def get_json(
@ -1154,8 +1151,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = ..., ignore_backoff: bool = ...,
try_trailing_slash_on_400: bool = ..., try_trailing_slash_on_400: bool = ...,
parser: ByteParser[T] = ..., parser: ByteParser[T] = ...,
) -> T: ) -> T: ...
...
async def get_json( async def get_json(
self, self,
@ -1236,8 +1232,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False, ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False, try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None, parser: Literal[None] = None,
) -> Tuple[JsonDict, Dict[bytes, List[bytes]]]: ) -> Tuple[JsonDict, Dict[bytes, List[bytes]]]: ...
...
@overload @overload
async def get_json_with_headers( async def get_json_with_headers(
@ -1250,8 +1245,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = ..., ignore_backoff: bool = ...,
try_trailing_slash_on_400: bool = ..., try_trailing_slash_on_400: bool = ...,
parser: ByteParser[T] = ..., parser: ByteParser[T] = ...,
) -> Tuple[T, Dict[bytes, List[bytes]]]: ) -> Tuple[T, Dict[bytes, List[bytes]]]: ...
...
async def get_json_with_headers( async def get_json_with_headers(
self, self,

View file

@ -61,20 +61,17 @@ logger = logging.getLogger(__name__)
@overload @overload
def parse_integer(request: Request, name: str, default: int) -> int: def parse_integer(request: Request, name: str, default: int) -> int: ...
...
@overload @overload
def parse_integer(request: Request, name: str, *, required: Literal[True]) -> int: def parse_integer(request: Request, name: str, *, required: Literal[True]) -> int: ...
...
@overload @overload
def parse_integer( def parse_integer(
request: Request, name: str, default: Optional[int] = None, required: bool = False request: Request, name: str, default: Optional[int] = None, required: bool = False
) -> Optional[int]: ) -> Optional[int]: ...
...
def parse_integer( def parse_integer(
@ -105,8 +102,7 @@ def parse_integer_from_args(
args: Mapping[bytes, Sequence[bytes]], args: Mapping[bytes, Sequence[bytes]],
name: str, name: str,
default: Optional[int] = None, default: Optional[int] = None,
) -> Optional[int]: ) -> Optional[int]: ...
...
@overload @overload
@ -115,8 +111,7 @@ def parse_integer_from_args(
name: str, name: str,
*, *,
required: Literal[True], required: Literal[True],
) -> int: ) -> int: ...
...
@overload @overload
@ -125,8 +120,7 @@ def parse_integer_from_args(
name: str, name: str,
default: Optional[int] = None, default: Optional[int] = None,
required: bool = False, required: bool = False,
) -> Optional[int]: ) -> Optional[int]: ...
...
def parse_integer_from_args( def parse_integer_from_args(
@ -172,20 +166,17 @@ def parse_integer_from_args(
@overload @overload
def parse_boolean(request: Request, name: str, default: bool) -> bool: def parse_boolean(request: Request, name: str, default: bool) -> bool: ...
...
@overload @overload
def parse_boolean(request: Request, name: str, *, required: Literal[True]) -> bool: def parse_boolean(request: Request, name: str, *, required: Literal[True]) -> bool: ...
...
@overload @overload
def parse_boolean( def parse_boolean(
request: Request, name: str, default: Optional[bool] = None, required: bool = False request: Request, name: str, default: Optional[bool] = None, required: bool = False
) -> Optional[bool]: ) -> Optional[bool]: ...
...
def parse_boolean( def parse_boolean(
@ -216,8 +207,7 @@ def parse_boolean_from_args(
args: Mapping[bytes, Sequence[bytes]], args: Mapping[bytes, Sequence[bytes]],
name: str, name: str,
default: bool, default: bool,
) -> bool: ) -> bool: ...
...
@overload @overload
@ -226,8 +216,7 @@ def parse_boolean_from_args(
name: str, name: str,
*, *,
required: Literal[True], required: Literal[True],
) -> bool: ) -> bool: ...
...
@overload @overload
@ -236,8 +225,7 @@ def parse_boolean_from_args(
name: str, name: str,
default: Optional[bool] = None, default: Optional[bool] = None,
required: bool = False, required: bool = False,
) -> Optional[bool]: ) -> Optional[bool]: ...
...
def parse_boolean_from_args( def parse_boolean_from_args(
@ -289,8 +277,7 @@ def parse_bytes_from_args(
args: Mapping[bytes, Sequence[bytes]], args: Mapping[bytes, Sequence[bytes]],
name: str, name: str,
default: Optional[bytes] = None, default: Optional[bytes] = None,
) -> Optional[bytes]: ) -> Optional[bytes]: ...
...
@overload @overload
@ -300,8 +287,7 @@ def parse_bytes_from_args(
default: Literal[None] = None, default: Literal[None] = None,
*, *,
required: Literal[True], required: Literal[True],
) -> bytes: ) -> bytes: ...
...
@overload @overload
@ -310,8 +296,7 @@ def parse_bytes_from_args(
name: str, name: str,
default: Optional[bytes] = None, default: Optional[bytes] = None,
required: bool = False, required: bool = False,
) -> Optional[bytes]: ) -> Optional[bytes]: ...
...
def parse_bytes_from_args( def parse_bytes_from_args(
@ -355,8 +340,7 @@ def parse_string(
*, *,
allowed_values: Optional[StrCollection] = None, allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii", encoding: str = "ascii",
) -> str: ) -> str: ...
...
@overload @overload
@ -367,8 +351,7 @@ def parse_string(
required: Literal[True], required: Literal[True],
allowed_values: Optional[StrCollection] = None, allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii", encoding: str = "ascii",
) -> str: ) -> str: ...
...
@overload @overload
@ -380,8 +363,7 @@ def parse_string(
required: bool = False, required: bool = False,
allowed_values: Optional[StrCollection] = None, allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii", encoding: str = "ascii",
) -> Optional[str]: ) -> Optional[str]: ...
...
def parse_string( def parse_string(
@ -437,8 +419,7 @@ def parse_enum(
name: str, name: str,
E: Type[EnumT], E: Type[EnumT],
default: EnumT, default: EnumT,
) -> EnumT: ) -> EnumT: ...
...
@overload @overload
@ -448,8 +429,7 @@ def parse_enum(
E: Type[EnumT], E: Type[EnumT],
*, *,
required: Literal[True], required: Literal[True],
) -> EnumT: ) -> EnumT: ...
...
def parse_enum( def parse_enum(
@ -526,8 +506,7 @@ def parse_strings_from_args(
*, *,
allowed_values: Optional[StrCollection] = None, allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii", encoding: str = "ascii",
) -> Optional[List[str]]: ) -> Optional[List[str]]: ...
...
@overload @overload
@ -538,8 +517,7 @@ def parse_strings_from_args(
*, *,
allowed_values: Optional[StrCollection] = None, allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii", encoding: str = "ascii",
) -> List[str]: ) -> List[str]: ...
...
@overload @overload
@ -550,8 +528,7 @@ def parse_strings_from_args(
required: Literal[True], required: Literal[True],
allowed_values: Optional[StrCollection] = None, allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii", encoding: str = "ascii",
) -> List[str]: ) -> List[str]: ...
...
@overload @overload
@ -563,8 +540,7 @@ def parse_strings_from_args(
required: bool = False, required: bool = False,
allowed_values: Optional[StrCollection] = None, allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii", encoding: str = "ascii",
) -> Optional[List[str]]: ) -> Optional[List[str]]: ...
...
def parse_strings_from_args( def parse_strings_from_args(
@ -625,8 +601,7 @@ def parse_string_from_args(
*, *,
allowed_values: Optional[StrCollection] = None, allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii", encoding: str = "ascii",
) -> Optional[str]: ) -> Optional[str]: ...
...
@overload @overload
@ -638,8 +613,7 @@ def parse_string_from_args(
required: Literal[True], required: Literal[True],
allowed_values: Optional[StrCollection] = None, allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii", encoding: str = "ascii",
) -> str: ) -> str: ...
...
@overload @overload
@ -650,8 +624,7 @@ def parse_string_from_args(
required: bool = False, required: bool = False,
allowed_values: Optional[StrCollection] = None, allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii", encoding: str = "ascii",
) -> Optional[str]: ) -> Optional[str]: ...
...
def parse_string_from_args( def parse_string_from_args(
@ -704,22 +677,19 @@ def parse_string_from_args(
@overload @overload
def parse_json_value_from_request(request: Request) -> JsonDict: def parse_json_value_from_request(request: Request) -> JsonDict: ...
...
@overload @overload
def parse_json_value_from_request( def parse_json_value_from_request(
request: Request, allow_empty_body: Literal[False] request: Request, allow_empty_body: Literal[False]
) -> JsonDict: ) -> JsonDict: ...
...
@overload @overload
def parse_json_value_from_request( def parse_json_value_from_request(
request: Request, allow_empty_body: bool = False request: Request, allow_empty_body: bool = False
) -> Optional[JsonDict]: ) -> Optional[JsonDict]: ...
...
def parse_json_value_from_request( def parse_json_value_from_request(
@ -847,7 +817,6 @@ def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None:
class RestServlet: class RestServlet:
"""A Synapse REST Servlet. """A Synapse REST Servlet.
An implementing class can either provide its own custom 'register' method, An implementing class can either provide its own custom 'register' method,

View file

@ -744,8 +744,7 @@ def preserve_fn(
@overload @overload
def preserve_fn(f: Callable[P, R]) -> Callable[P, "defer.Deferred[R]"]: def preserve_fn(f: Callable[P, R]) -> Callable[P, "defer.Deferred[R]"]: ...
...
def preserve_fn( def preserve_fn(
@ -774,8 +773,7 @@ def run_in_background(
@overload @overload
def run_in_background( def run_in_background(
f: Callable[P, R], *args: P.args, **kwargs: P.kwargs f: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> "defer.Deferred[R]": ) -> "defer.Deferred[R]": ...
...
def run_in_background( # type: ignore[misc] def run_in_background( # type: ignore[misc]

View file

@ -388,15 +388,13 @@ def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
@overload @overload
def ensure_active_span( def ensure_active_span(
message: str, message: str,
) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]: ) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]: ...
...
@overload @overload
def ensure_active_span( def ensure_active_span(
message: str, ret: T message: str, ret: T
) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]: ) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]: ...
...
def ensure_active_span( def ensure_active_span(

View file

@ -1002,9 +1002,9 @@ class MediaRepository:
) )
t_width = min(m_width, t_width) t_width = min(m_width, t_width)
t_height = min(m_height, t_height) t_height = min(m_height, t_height)
thumbnails[ thumbnails[(t_width, t_height, requirement.media_type)] = (
(t_width, t_height, requirement.media_type) requirement.method
] = requirement.method )
# Now we generate the thumbnails for each dimension, store it # Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in thumbnails.items(): for (t_width, t_height, t_type), t_method in thumbnails.items():

View file

@ -42,14 +42,12 @@ class JemallocStats:
@overload @overload
def _mallctl( def _mallctl(
self, name: str, read: Literal[True] = True, write: Optional[int] = None self, name: str, read: Literal[True] = True, write: Optional[int] = None
) -> int: ) -> int: ...
...
@overload @overload
def _mallctl( def _mallctl(
self, name: str, read: Literal[False], write: Optional[int] = None self, name: str, read: Literal[False], write: Optional[int] = None
) -> None: ) -> None: ...
...
def _mallctl( def _mallctl(
self, name: str, read: bool = True, write: Optional[int] = None self, name: str, read: bool = True, write: Optional[int] = None

View file

@ -469,8 +469,7 @@ class Notifier:
new_token: RoomStreamToken, new_token: RoomStreamToken,
users: Optional[Collection[Union[str, UserID]]] = None, users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[StrCollection] = None, rooms: Optional[StrCollection] = None,
) -> None: ) -> None: ...
...
@overload @overload
def on_new_event( def on_new_event(
@ -479,8 +478,7 @@ class Notifier:
new_token: MultiWriterStreamToken, new_token: MultiWriterStreamToken,
users: Optional[Collection[Union[str, UserID]]] = None, users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[StrCollection] = None, rooms: Optional[StrCollection] = None,
) -> None: ) -> None: ...
...
@overload @overload
def on_new_event( def on_new_event(
@ -497,8 +495,7 @@ class Notifier:
new_token: int, new_token: int,
users: Optional[Collection[Union[str, UserID]]] = None, users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[StrCollection] = None, rooms: Optional[StrCollection] = None,
) -> None: ) -> None: ...
...
def on_new_event( def on_new_event(
self, self,

View file

@ -377,12 +377,14 @@ class Mailer:
# #
# Note that many email clients will not render the unsubscribe link # Note that many email clients will not render the unsubscribe link
# unless DKIM, etc. is properly setup. # unless DKIM, etc. is properly setup.
additional_headers={ additional_headers=(
"List-Unsubscribe-Post": "List-Unsubscribe=One-Click", {
"List-Unsubscribe": f"<{unsubscribe_link}>", "List-Unsubscribe-Post": "List-Unsubscribe=One-Click",
} "List-Unsubscribe": f"<{unsubscribe_link}>",
if unsubscribe_link }
else None, if unsubscribe_link
else None
),
) )
async def _get_room_vars( async def _get_room_vars(

View file

@ -259,9 +259,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
url_args.append(txn_id) url_args.append(txn_id)
if cls.METHOD == "POST": if cls.METHOD == "POST":
request_func: Callable[ request_func: Callable[..., Awaitable[Any]] = (
..., Awaitable[Any] client.post_json_get_json
] = client.post_json_get_json )
elif cls.METHOD == "PUT": elif cls.METHOD == "PUT":
request_func = client.put_json request_func = client.put_json
elif cls.METHOD == "GET": elif cls.METHOD == "GET":

View file

@ -70,9 +70,9 @@ class ExternalCache:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
if hs.config.redis.redis_enabled: if hs.config.redis.redis_enabled:
self._redis_connection: Optional[ self._redis_connection: Optional["ConnectionHandler"] = (
"ConnectionHandler" hs.get_outbound_redis_connection()
] = hs.get_outbound_redis_connection() )
else: else:
self._redis_connection = None self._redis_connection = None

View file

@ -237,10 +237,12 @@ class PurgeHistoryStatusRestServlet(RestServlet):
raise NotFoundError("purge id '%s' not found" % purge_id) raise NotFoundError("purge id '%s' not found" % purge_id)
result: JsonDict = { result: JsonDict = {
"status": purge_task.status "status": (
if purge_task.status == TaskStatus.COMPLETE purge_task.status
or purge_task.status == TaskStatus.FAILED if purge_task.status == TaskStatus.COMPLETE
else "active", or purge_task.status == TaskStatus.FAILED
else "active"
),
} }
if purge_task.error: if purge_task.error:
result["error"] = purge_task.error result["error"] = purge_task.error

View file

@ -1184,12 +1184,14 @@ class RateLimitRestServlet(RestServlet):
# convert `null` to `0` for consistency # convert `null` to `0` for consistency
# both values do the same in retelimit handler # both values do the same in retelimit handler
ret = { ret = {
"messages_per_second": 0 "messages_per_second": (
if ratelimit.messages_per_second is None 0
else ratelimit.messages_per_second, if ratelimit.messages_per_second is None
"burst_count": 0 else ratelimit.messages_per_second
if ratelimit.burst_count is None ),
else ratelimit.burst_count, "burst_count": (
0 if ratelimit.burst_count is None else ratelimit.burst_count
),
} }
else: else:
ret = {} ret = {}

View file

@ -112,9 +112,9 @@ class AccountDataServlet(RestServlet):
self._hs.config.experimental.msc4010_push_rules_account_data self._hs.config.experimental.msc4010_push_rules_account_data
and account_data_type == AccountDataTypes.PUSH_RULES and account_data_type == AccountDataTypes.PUSH_RULES
): ):
account_data: Optional[ account_data: Optional[JsonMapping] = (
JsonMapping await self._push_rules_handler.push_rules_for_user(requester.user)
] = await self._push_rules_handler.push_rules_for_user(requester.user) )
else: else:
account_data = await self.store.get_global_account_data_by_type_for_user( account_data = await self.store.get_global_account_data_by_type_for_user(
user_id, account_data_type user_id, account_data_type

View file

@ -313,12 +313,12 @@ class SyncRestServlet(RestServlet):
# https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
# states that this field should always be included, as long as the server supports the feature. # states that this field should always be included, as long as the server supports the feature.
response[ response["org.matrix.msc2732.device_unused_fallback_key_types"] = (
"org.matrix.msc2732.device_unused_fallback_key_types" sync_result.device_unused_fallback_key_types
] = sync_result.device_unused_fallback_key_types )
response[ response["device_unused_fallback_key_types"] = (
"device_unused_fallback_key_types" sync_result.device_unused_fallback_key_types
] = sync_result.device_unused_fallback_key_types )
if joined: if joined:
response["rooms"][Membership.JOIN] = joined response["rooms"][Membership.JOIN] = joined
@ -543,9 +543,9 @@ class SyncRestServlet(RestServlet):
if room.unread_thread_notifications: if room.unread_thread_notifications:
result["unread_thread_notifications"] = room.unread_thread_notifications result["unread_thread_notifications"] = room.unread_thread_notifications
if self._msc3773_enabled: if self._msc3773_enabled:
result[ result["org.matrix.msc3773.unread_thread_notifications"] = (
"org.matrix.msc3773.unread_thread_notifications" room.unread_thread_notifications
] = room.unread_thread_notifications )
result["summary"] = room.summary result["summary"] = room.summary
if self._msc2654_enabled: if self._msc2654_enabled:
result["org.matrix.msc2654.unread_count"] = room.unread_count result["org.matrix.msc2654.unread_count"] = room.unread_count

View file

@ -191,10 +191,10 @@ class RemoteKey(RestServlet):
server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {} server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
for server_name, key_ids in query.items(): for server_name, key_ids in query.items():
if key_ids: if key_ids:
results: Mapping[ results: Mapping[str, Optional[FetchKeyResultForRemote]] = (
str, Optional[FetchKeyResultForRemote] await self.store.get_server_keys_json_for_remote(
] = await self.store.get_server_keys_json_for_remote( server_name, key_ids
server_name, key_ids )
) )
else: else:
results = await self.store.get_all_server_keys_json_for_remote( results = await self.store.get_all_server_keys_json_for_remote(

View file

@ -603,15 +603,15 @@ class StateResolutionHandler:
self.resolve_linearizer = Linearizer(name="state_resolve_lock") self.resolve_linearizer = Linearizer(name="state_resolve_lock")
# dict of set of event_ids -> _StateCacheEntry. # dict of set of event_ids -> _StateCacheEntry.
self._state_cache: ExpiringCache[ self._state_cache: ExpiringCache[FrozenSet[int], _StateCacheEntry] = (
FrozenSet[int], _StateCacheEntry ExpiringCache(
] = ExpiringCache( cache_name="state_cache",
cache_name="state_cache", clock=self.clock,
clock=self.clock, max_len=100000,
max_len=100000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, iterable=True,
iterable=True, reset_expiry_on_get=True,
reset_expiry_on_get=True, )
) )
# #

View file

@ -52,8 +52,7 @@ class Clock(Protocol):
# This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests. # This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests.
# We only ever sleep(0) though, so that other async functions can make forward # We only ever sleep(0) though, so that other async functions can make forward
# progress without waiting for stateres to complete. # progress without waiting for stateres to complete.
def sleep(self, duration_ms: float) -> Awaitable[None]: def sleep(self, duration_ms: float) -> Awaitable[None]: ...
...
class StateResolutionStore(Protocol): class StateResolutionStore(Protocol):
@ -61,13 +60,11 @@ class StateResolutionStore(Protocol):
# TestStateResolutionStore in tests. # TestStateResolutionStore in tests.
def get_events( def get_events(
self, event_ids: StrCollection, allow_rejected: bool = False self, event_ids: StrCollection, allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]: ) -> Awaitable[Dict[str, EventBase]]: ...
...
def get_auth_chain_difference( def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]] self, room_id: str, state_sets: List[Set[str]]
) -> Awaitable[Set[str]]: ) -> Awaitable[Set[str]]: ...
...
# We want to await to the reactor occasionally during state res when dealing # We want to await to the reactor occasionally during state res when dealing
@ -742,8 +739,7 @@ async def _get_event(
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: StateResolutionStore, state_res_store: StateResolutionStore,
allow_none: Literal[False] = False, allow_none: Literal[False] = False,
) -> EventBase: ) -> EventBase: ...
...
@overload @overload
@ -753,8 +749,7 @@ async def _get_event(
event_map: Dict[str, EventBase], event_map: Dict[str, EventBase],
state_res_store: StateResolutionStore, state_res_store: StateResolutionStore,
allow_none: Literal[True], allow_none: Literal[True],
) -> Optional[EventBase]: ) -> Optional[EventBase]: ...
...
async def _get_event( async def _get_event(

View file

@ -836,9 +836,9 @@ class BackgroundUpdater:
c.execute(sql) c.execute(sql)
if isinstance(self.db_pool.engine, engines.PostgresEngine): if isinstance(self.db_pool.engine, engines.PostgresEngine):
runner: Optional[ runner: Optional[Callable[[LoggingDatabaseConnection], None]] = (
Callable[[LoggingDatabaseConnection], None] create_index_psql
] = create_index_psql )
elif psql_only: elif psql_only:
runner = None runner = None
else: else:

View file

@ -773,9 +773,9 @@ class EventsPersistenceStorageController:
) )
# Remove any events which are prev_events of any existing events. # Remove any events which are prev_events of any existing events.
existing_prevs: Collection[ existing_prevs: Collection[str] = (
str await self.persist_events_store._get_events_which_are_prevs(result)
] = await self.persist_events_store._get_events_which_are_prevs(result) )
result.difference_update(existing_prevs) result.difference_update(existing_prevs)
# Finally handle the case where the new events have soft-failed prev # Finally handle the case where the new events have soft-failed prev

View file

@ -111,8 +111,7 @@ class _PoolConnection(Connection):
A Connection from twisted.enterprise.adbapi.Connection. A Connection from twisted.enterprise.adbapi.Connection.
""" """
def reconnect(self) -> None: def reconnect(self) -> None: ...
...
def make_pool( def make_pool(
@ -1603,8 +1602,7 @@ class DatabasePool:
retcols: Collection[str], retcols: Collection[str],
allow_none: Literal[False] = False, allow_none: Literal[False] = False,
desc: str = "simple_select_one", desc: str = "simple_select_one",
) -> Tuple[Any, ...]: ) -> Tuple[Any, ...]: ...
...
@overload @overload
async def simple_select_one( async def simple_select_one(
@ -1614,8 +1612,7 @@ class DatabasePool:
retcols: Collection[str], retcols: Collection[str],
allow_none: Literal[True] = True, allow_none: Literal[True] = True,
desc: str = "simple_select_one", desc: str = "simple_select_one",
) -> Optional[Tuple[Any, ...]]: ) -> Optional[Tuple[Any, ...]]: ...
...
async def simple_select_one( async def simple_select_one(
self, self,
@ -1654,8 +1651,7 @@ class DatabasePool:
retcol: str, retcol: str,
allow_none: Literal[False] = False, allow_none: Literal[False] = False,
desc: str = "simple_select_one_onecol", desc: str = "simple_select_one_onecol",
) -> Any: ) -> Any: ...
...
@overload @overload
async def simple_select_one_onecol( async def simple_select_one_onecol(
@ -1665,8 +1661,7 @@ class DatabasePool:
retcol: str, retcol: str,
allow_none: Literal[True] = True, allow_none: Literal[True] = True,
desc: str = "simple_select_one_onecol", desc: str = "simple_select_one_onecol",
) -> Optional[Any]: ) -> Optional[Any]: ...
...
async def simple_select_one_onecol( async def simple_select_one_onecol(
self, self,
@ -1706,8 +1701,7 @@ class DatabasePool:
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
retcol: str, retcol: str,
allow_none: Literal[False] = False, allow_none: Literal[False] = False,
) -> Any: ) -> Any: ...
...
@overload @overload
@classmethod @classmethod
@ -1718,8 +1712,7 @@ class DatabasePool:
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
retcol: str, retcol: str,
allow_none: Literal[True] = True, allow_none: Literal[True] = True,
) -> Optional[Any]: ) -> Optional[Any]: ...
...
@classmethod @classmethod
def simple_select_one_onecol_txn( def simple_select_one_onecol_txn(
@ -2501,8 +2494,7 @@ def make_tuple_in_list_sql_clause(
database_engine: BaseDatabaseEngine, database_engine: BaseDatabaseEngine,
columns: Tuple[str, str], columns: Tuple[str, str],
iterable: Collection[Tuple[Any, Any]], iterable: Collection[Tuple[Any, Any]],
) -> Tuple[str, list]: ) -> Tuple[str, list]: ...
...
def make_tuple_in_list_sql_clause( def make_tuple_in_list_sql_clause(

View file

@ -1701,9 +1701,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies # Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists. # the device exists.
self.device_id_exists_cache: LruCache[ self.device_id_exists_cache: LruCache[Tuple[str, str], Literal[True]] = (
Tuple[str, str], Literal[True] LruCache(cache_name="device_id_exists", max_size=10000)
] = LruCache(cache_name="device_id_exists", max_size=10000) )
async def store_device( async def store_device(
self, self,

View file

@ -256,8 +256,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
self, self,
query_list: Collection[Tuple[str, Optional[str]]], query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: Literal[False] = False, include_all_devices: Literal[False] = False,
) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: ...
...
@overload @overload
async def get_e2e_device_keys_and_signatures( async def get_e2e_device_keys_and_signatures(
@ -265,8 +264,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
query_list: Collection[Tuple[str, Optional[str]]], query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: bool = False, include_all_devices: bool = False,
include_deleted_devices: Literal[False] = False, include_deleted_devices: Literal[False] = False,
) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: ...
...
@overload @overload
async def get_e2e_device_keys_and_signatures( async def get_e2e_device_keys_and_signatures(
@ -274,8 +272,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
query_list: Collection[Tuple[str, Optional[str]]], query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: Literal[True], include_all_devices: Literal[True],
include_deleted_devices: Literal[True], include_deleted_devices: Literal[True],
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: ...
...
@trace @trace
@cancellable @cancellable

View file

@ -1292,9 +1292,9 @@ class PersistEventsStore:
Returns: Returns:
filtered list filtered list
""" """
new_events_and_contexts: OrderedDict[ new_events_and_contexts: OrderedDict[str, Tuple[EventBase, EventContext]] = (
str, Tuple[EventBase, EventContext] OrderedDict()
] = OrderedDict() )
for event, context in events_and_contexts: for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id) prev_event_context = new_events_and_contexts.get(event.event_id)
if prev_event_context: if prev_event_context:

View file

@ -263,13 +263,13 @@ class EventsWorkerStore(SQLBaseStore):
5 * 60 * 1000, 5 * 60 * 1000,
) )
self._get_event_cache: AsyncLruCache[ self._get_event_cache: AsyncLruCache[Tuple[str], EventCacheEntry] = (
Tuple[str], EventCacheEntry AsyncLruCache(
] = AsyncLruCache( cache_name="*getEvent*",
cache_name="*getEvent*", max_size=hs.config.caches.event_cache_size,
max_size=hs.config.caches.event_cache_size, # `extra_index_cb` Returns a tuple as that is the key type
# `extra_index_cb` Returns a tuple as that is the key type extra_index_cb=lambda _, v: (v.event.room_id,),
extra_index_cb=lambda _, v: (v.event.room_id,), )
) )
# Map from event ID to a deferred that will result in a map from event # Map from event ID to a deferred that will result in a map from event
@ -459,8 +459,7 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected: bool = ..., allow_rejected: bool = ...,
allow_none: Literal[False] = ..., allow_none: Literal[False] = ...,
check_room_id: Optional[str] = ..., check_room_id: Optional[str] = ...,
) -> EventBase: ) -> EventBase: ...
...
@overload @overload
async def get_event( async def get_event(
@ -471,8 +470,7 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected: bool = ..., allow_rejected: bool = ...,
allow_none: Literal[True] = ..., allow_none: Literal[True] = ...,
check_room_id: Optional[str] = ..., check_room_id: Optional[str] = ...,
) -> Optional[EventBase]: ) -> Optional[EventBase]: ...
...
@cancellable @cancellable
async def get_event( async def get_event(
@ -800,9 +798,9 @@ class EventsWorkerStore(SQLBaseStore):
# to all the events we pulled from the DB (this will result in this # to all the events we pulled from the DB (this will result in this
# function returning more events than requested, but that can happen # function returning more events than requested, but that can happen
# already due to `_get_events_from_db`). # already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[ fetching_deferred: ObservableDeferred[Dict[str, EventCacheEntry]] = (
Dict[str, EventCacheEntry] ObservableDeferred(defer.Deferred(), consumeErrors=True)
] = ObservableDeferred(defer.Deferred(), consumeErrors=True) )
for event_id in missing_events_ids: for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred self._current_event_fetches[event_id] = fetching_deferred
@ -1871,9 +1869,9 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?" " LIMIT ?"
) )
txn.execute(sql, (-last_id, -current_id, instance_name, limit)) txn.execute(sql, (-last_id, -current_id, instance_name, limit))
new_event_updates: List[ new_event_updates: List[Tuple[int, Tuple[str, str, str, str, str, str]]] = (
Tuple[int, Tuple[str, str, str, str, str, str]] []
] = [] )
row: Tuple[int, str, str, str, str, str, str] row: Tuple[int, str, str, str, str, str, str]
# Type safety: iterating over `txn` yields `Tuple`, i.e. # Type safety: iterating over `txn` yields `Tuple`, i.e.
# `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a

View file

@ -79,9 +79,9 @@ class LockStore(SQLBaseStore):
# A map from `(lock_name, lock_key)` to lock that we think we # A map from `(lock_name, lock_key)` to lock that we think we
# currently hold. # currently hold.
self._live_lock_tokens: WeakValueDictionary[ self._live_lock_tokens: WeakValueDictionary[Tuple[str, str], Lock] = (
Tuple[str, str], Lock WeakValueDictionary()
] = WeakValueDictionary() )
# A map from `(lock_name, lock_key, token)` to read/write lock that we # A map from `(lock_name, lock_key, token)` to read/write lock that we
# think we currently hold. For a given lock_name/lock_key, there can be # think we currently hold. For a given lock_name/lock_key, there can be

View file

@ -158,9 +158,9 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
) )
if hs.config.media.can_load_media_repo: if hs.config.media.can_load_media_repo:
self.unused_expiration_time: Optional[ self.unused_expiration_time: Optional[int] = (
int hs.config.media.unused_expiration_time
] = hs.config.media.unused_expiration_time )
else: else:
self.unused_expiration_time = None self.unused_expiration_time = None

View file

@ -394,9 +394,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
content: JsonDict = {} content: JsonDict = {}
for receipt_type, user_id, event_id, data in rows: for receipt_type, user_id, event_id, data in rows:
content.setdefault(event_id, {}).setdefault(receipt_type, {})[ content.setdefault(event_id, {}).setdefault(receipt_type, {})[user_id] = (
user_id db_to_json(data)
] = db_to_json(data) )
return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}] return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]
@ -483,9 +483,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
if user_id in receipt_type_dict: # existing receipt if user_id in receipt_type_dict: # existing receipt
# is the existing receipt threaded and we are currently processing an unthreaded one? # is the existing receipt threaded and we are currently processing an unthreaded one?
if "thread_id" in receipt_type_dict[user_id] and not thread_id: if "thread_id" in receipt_type_dict[user_id] and not thread_id:
receipt_type_dict[ receipt_type_dict[user_id] = (
user_id receipt_data # replace with unthreaded one
] = receipt_data # replace with unthreaded one )
else: # receipt does not exist, just set it else: # receipt does not exist, just set it
receipt_type_dict[user_id] = receipt_data receipt_type_dict[user_id] = receipt_data
if thread_id: if thread_id:

View file

@ -768,12 +768,10 @@ class StateMapWrapper(Dict[StateKey, str]):
return super().__getitem__(key) return super().__getitem__(key)
@overload @overload
def get(self, key: Tuple[str, str]) -> Optional[str]: def get(self, key: Tuple[str, str]) -> Optional[str]: ...
...
@overload @overload
def get(self, key: Tuple[str, str], default: Union[str, _T]) -> Union[str, _T]: def get(self, key: Tuple[str, str], default: Union[str, _T]) -> Union[str, _T]: ...
...
def get( def get(
self, key: StateKey, default: Union[str, _T, None] = None self, key: StateKey, default: Union[str, _T, None] = None

View file

@ -988,8 +988,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn: LoggingTransaction, txn: LoggingTransaction,
event_id: str, event_id: str,
allow_none: Literal[False] = False, allow_none: Literal[False] = False,
) -> int: ) -> int: ...
...
@overload @overload
def get_stream_id_for_event_txn( def get_stream_id_for_event_txn(
@ -997,8 +996,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn: LoggingTransaction, txn: LoggingTransaction,
event_id: str, event_id: str,
allow_none: bool = False, allow_none: bool = False,
) -> Optional[int]: ) -> Optional[int]: ...
...
def get_stream_id_for_event_txn( def get_stream_id_for_event_txn(
self, self,
@ -1476,12 +1474,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
_EventDictReturn(event_id, topological_ordering, stream_ordering) _EventDictReturn(event_id, topological_ordering, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn for event_id, instance_name, topological_ordering, stream_ordering in txn
if _filter_results( if _filter_results(
lower_token=to_token lower_token=(
if direction == Direction.BACKWARDS to_token if direction == Direction.BACKWARDS else from_token
else from_token, ),
upper_token=from_token upper_token=(
if direction == Direction.BACKWARDS from_token if direction == Direction.BACKWARDS else to_token
else to_token, ),
instance_name=instance_name, instance_name=instance_name,
topological_ordering=topological_ordering, topological_ordering=topological_ordering,
stream_ordering=stream_ordering, stream_ordering=stream_ordering,

View file

@ -136,12 +136,12 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
"status": task.status, "status": task.status,
"timestamp": task.timestamp, "timestamp": task.timestamp,
"resource_id": task.resource_id, "resource_id": task.resource_id,
"params": None "params": (
if task.params is None None if task.params is None else json_encoder.encode(task.params)
else json_encoder.encode(task.params), ),
"result": None "result": (
if task.result is None None if task.result is None else json_encoder.encode(task.result)
else json_encoder.encode(task.result), ),
"error": task.error, "error": task.error,
}, },
desc="insert_scheduled_task", desc="insert_scheduled_task",

View file

@ -745,9 +745,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
p.user_id, p.user_id,
get_localpart_from_id(p.user_id), get_localpart_from_id(p.user_id),
get_domain_from_id(p.user_id), get_domain_from_id(p.user_id),
_filter_text_for_index(p.display_name) (
if p.display_name _filter_text_for_index(p.display_name)
else None, if p.display_name
else None
),
) )
for p in profiles for p in profiles
], ],

View file

@ -120,11 +120,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# TODO: this hasn't been tuned yet # TODO: this hasn't been tuned yet
50000, 50000,
) )
self._state_group_members_cache: DictionaryCache[ self._state_group_members_cache: DictionaryCache[int, StateKey, str] = (
int, StateKey, str DictionaryCache(
] = DictionaryCache( "*stateGroupMembersCache*",
"*stateGroupMembersCache*", 500000,
500000, )
) )
def get_max_state_group_txn(txn: Cursor) -> int: def get_max_state_group_txn(txn: Cursor) -> int:

View file

@ -48,8 +48,7 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
@property @property
@abc.abstractmethod @abc.abstractmethod
def single_threaded(self) -> bool: def single_threaded(self) -> bool: ...
...
@property @property
@abc.abstractmethod @abc.abstractmethod
@ -68,8 +67,7 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
@abc.abstractmethod @abc.abstractmethod
def check_database( def check_database(
self, db_conn: ConnectionType, allow_outdated_version: bool = False self, db_conn: ConnectionType, allow_outdated_version: bool = False
) -> None: ) -> None: ...
...
@abc.abstractmethod @abc.abstractmethod
def check_new_database(self, txn: CursorType) -> None: def check_new_database(self, txn: CursorType) -> None:
@ -79,27 +77,22 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
... ...
@abc.abstractmethod @abc.abstractmethod
def convert_param_style(self, sql: str) -> str: def convert_param_style(self, sql: str) -> str: ...
...
# This method would ideally take a plain ConnectionType, but it seems that # This method would ideally take a plain ConnectionType, but it seems that
# the Sqlite engine expects to use LoggingDatabaseConnection.cursor # the Sqlite engine expects to use LoggingDatabaseConnection.cursor
# instead of sqlite3.Connection.cursor: only the former takes a txn_name. # instead of sqlite3.Connection.cursor: only the former takes a txn_name.
@abc.abstractmethod @abc.abstractmethod
def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None: def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None: ...
...
@abc.abstractmethod @abc.abstractmethod
def is_deadlock(self, error: Exception) -> bool: def is_deadlock(self, error: Exception) -> bool: ...
...
@abc.abstractmethod @abc.abstractmethod
def is_connection_closed(self, conn: ConnectionType) -> bool: def is_connection_closed(self, conn: ConnectionType) -> bool: ...
...
@abc.abstractmethod @abc.abstractmethod
def lock_table(self, txn: Cursor, table: str) -> None: def lock_table(self, txn: Cursor, table: str) -> None: ...
...
@property @property
@abc.abstractmethod @abc.abstractmethod

View file

@ -42,20 +42,17 @@ SQLQueryParameters = Union[Sequence[Any], Mapping[str, Any]]
class Cursor(Protocol): class Cursor(Protocol):
def execute(self, sql: str, parameters: SQLQueryParameters = ...) -> Any: def execute(self, sql: str, parameters: SQLQueryParameters = ...) -> Any: ...
...
def executemany(self, sql: str, parameters: Sequence[SQLQueryParameters]) -> Any: def executemany(
... self, sql: str, parameters: Sequence[SQLQueryParameters]
) -> Any: ...
def fetchone(self) -> Optional[Tuple]: def fetchone(self) -> Optional[Tuple]: ...
...
def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]: def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]: ...
...
def fetchall(self) -> List[Tuple]: def fetchall(self) -> List[Tuple]: ...
...
@property @property
def description( def description(
@ -70,36 +67,28 @@ class Cursor(Protocol):
def rowcount(self) -> int: def rowcount(self) -> int:
return 0 return 0
def __iter__(self) -> Iterator[Tuple]: def __iter__(self) -> Iterator[Tuple]: ...
...
def close(self) -> None: def close(self) -> None: ...
...
class Connection(Protocol): class Connection(Protocol):
def cursor(self) -> Cursor: def cursor(self) -> Cursor: ...
...
def close(self) -> None: def close(self) -> None: ...
...
def commit(self) -> None: def commit(self) -> None: ...
...
def rollback(self) -> None: def rollback(self) -> None: ...
...
def __enter__(self) -> "Connection": def __enter__(self) -> "Connection": ...
...
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException], exc_value: Optional[BaseException],
traceback: Optional[TracebackType], traceback: Optional[TracebackType],
) -> Optional[bool]: ) -> Optional[bool]: ...
...
class DBAPI2Module(Protocol): class DBAPI2Module(Protocol):
@ -129,24 +118,20 @@ class DBAPI2Module(Protocol):
# explain why this is necessary for safety. TL;DR: we shouldn't be able to write # explain why this is necessary for safety. TL;DR: we shouldn't be able to write
# to `x`, only read from it. See also https://github.com/python/mypy/issues/6002 . # to `x`, only read from it. See also https://github.com/python/mypy/issues/6002 .
@property @property
def Warning(self) -> Type[Exception]: def Warning(self) -> Type[Exception]: ...
...
@property @property
def Error(self) -> Type[Exception]: def Error(self) -> Type[Exception]: ...
...
# Errors are divided into `InterfaceError`s (something went wrong in the database # Errors are divided into `InterfaceError`s (something went wrong in the database
# driver) and `DatabaseError`s (something went wrong in the database). These are # driver) and `DatabaseError`s (something went wrong in the database). These are
# both subclasses of `Error`, but we can't currently express this in type # both subclasses of `Error`, but we can't currently express this in type
# annotations due to https://github.com/python/mypy/issues/8397 # annotations due to https://github.com/python/mypy/issues/8397
@property @property
def InterfaceError(self) -> Type[Exception]: def InterfaceError(self) -> Type[Exception]: ...
...
@property @property
def DatabaseError(self) -> Type[Exception]: def DatabaseError(self) -> Type[Exception]: ...
...
# Everything below is a subclass of `DatabaseError`. # Everything below is a subclass of `DatabaseError`.
@ -155,8 +140,7 @@ class DBAPI2Module(Protocol):
# - An invalid date time was provided. # - An invalid date time was provided.
# - A string contained a null code point. # - A string contained a null code point.
@property @property
def DataError(self) -> Type[Exception]: def DataError(self) -> Type[Exception]: ...
...
# Roughly: something went wrong in the database, but it's not within the application # Roughly: something went wrong in the database, but it's not within the application
# programmer's control. Examples: # programmer's control. Examples:
@ -167,21 +151,18 @@ class DBAPI2Module(Protocol):
# - The database ran out of resources, such as storage, memory, connections, etc. # - The database ran out of resources, such as storage, memory, connections, etc.
# - The database encountered an error from the operating system. # - The database encountered an error from the operating system.
@property @property
def OperationalError(self) -> Type[Exception]: def OperationalError(self) -> Type[Exception]: ...
...
# Roughly: we've given the database data which breaks a rule we asked it to enforce. # Roughly: we've given the database data which breaks a rule we asked it to enforce.
# Examples: # Examples:
# - Stop, criminal scum! You violated the foreign key constraint # - Stop, criminal scum! You violated the foreign key constraint
# - Also check constraints, non-null constraints, etc. # - Also check constraints, non-null constraints, etc.
@property @property
def IntegrityError(self) -> Type[Exception]: def IntegrityError(self) -> Type[Exception]: ...
...
# Roughly: something went wrong within the database server itself. # Roughly: something went wrong within the database server itself.
@property @property
def InternalError(self) -> Type[Exception]: def InternalError(self) -> Type[Exception]: ...
...
# Roughly: the application did something silly that needs to be fixed. Examples: # Roughly: the application did something silly that needs to be fixed. Examples:
# - We don't have permissions to do something. # - We don't have permissions to do something.
@ -189,13 +170,11 @@ class DBAPI2Module(Protocol):
# - We tried to use a reserved name. # - We tried to use a reserved name.
# - We referred to a column that doesn't exist. # - We referred to a column that doesn't exist.
@property @property
def ProgrammingError(self) -> Type[Exception]: def ProgrammingError(self) -> Type[Exception]: ...
...
# Roughly: we've tried to do something that this database doesn't support. # Roughly: we've tried to do something that this database doesn't support.
@property @property
def NotSupportedError(self) -> Type[Exception]: def NotSupportedError(self) -> Type[Exception]: ...
...
# We originally wrote # We originally wrote
# def connect(self, *args, **kwargs) -> Connection: ... # def connect(self, *args, **kwargs) -> Connection: ...
@ -204,8 +183,7 @@ class DBAPI2Module(Protocol):
# psycopg2.connect doesn't have a mandatory positional argument. Instead, we use # psycopg2.connect doesn't have a mandatory positional argument. Instead, we use
# the following slightly unusual workaround. # the following slightly unusual workaround.
@property @property
def connect(self) -> Callable[..., Connection]: def connect(self) -> Callable[..., Connection]: ...
...
__all__ = ["Cursor", "Connection", "DBAPI2Module"] __all__ = ["Cursor", "Connection", "DBAPI2Module"]

View file

@ -57,6 +57,7 @@ class EventInternalMetadata:
(Added in synapse 0.99.0, so may be unreliable for events received before that) (Added in synapse 0.99.0, so may be unreliable for events received before that)
""" """
... ...
def get_send_on_behalf_of(self) -> Optional[str]: def get_send_on_behalf_of(self) -> Optional[str]:
"""Whether this server should send the event on behalf of another server. """Whether this server should send the event on behalf of another server.
This is used by the federation "send_join" API to forward the initial join This is used by the federation "send_join" API to forward the initial join
@ -65,6 +66,7 @@ class EventInternalMetadata:
returns a str with the name of the server this event is sent on behalf of. returns a str with the name of the server this event is sent on behalf of.
""" """
... ...
def need_to_check_redaction(self) -> bool: def need_to_check_redaction(self) -> bool:
"""Whether the redaction event needs to be rechecked when fetching """Whether the redaction event needs to be rechecked when fetching
from the database. from the database.
@ -76,6 +78,7 @@ class EventInternalMetadata:
due to auth rules, then this will always return false. due to auth rules, then this will always return false.
""" """
... ...
def is_soft_failed(self) -> bool: def is_soft_failed(self) -> bool:
"""Whether the event has been soft failed. """Whether the event has been soft failed.
@ -86,6 +89,7 @@ class EventInternalMetadata:
therefore not to current state). therefore not to current state).
""" """
... ...
def should_proactively_send(self) -> bool: def should_proactively_send(self) -> bool:
"""Whether the event, if ours, should be sent to other clients and """Whether the event, if ours, should be sent to other clients and
servers. servers.
@ -94,6 +98,7 @@ class EventInternalMetadata:
can still explicitly fetch the event. can still explicitly fetch the event.
""" """
... ...
def is_redacted(self) -> bool: def is_redacted(self) -> bool:
"""Whether the event has been redacted. """Whether the event has been redacted.
@ -101,6 +106,7 @@ class EventInternalMetadata:
marked as redacted without needing to make another database call. marked as redacted without needing to make another database call.
""" """
... ...
def is_notifiable(self) -> bool: def is_notifiable(self) -> bool:
"""Whether this event can trigger a push notification""" """Whether this event can trigger a push notification"""
... ...

View file

@ -976,12 +976,12 @@ class StreamToken:
return attr.evolve(self, **{key.value: new_value}) return attr.evolve(self, **{key.value: new_value})
@overload @overload
def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken: def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken: ...
...
@overload @overload
def get_field(self, key: Literal[StreamKeyType.RECEIPT]) -> MultiWriterStreamToken: def get_field(
... self, key: Literal[StreamKeyType.RECEIPT]
) -> MultiWriterStreamToken: ...
@overload @overload
def get_field( def get_field(
@ -995,14 +995,12 @@ class StreamToken:
StreamKeyType.TYPING, StreamKeyType.TYPING,
StreamKeyType.UN_PARTIAL_STATED_ROOMS, StreamKeyType.UN_PARTIAL_STATED_ROOMS,
], ],
) -> int: ) -> int: ...
...
@overload @overload
def get_field( def get_field(
self, key: StreamKeyType self, key: StreamKeyType
) -> Union[int, RoomStreamToken, MultiWriterStreamToken]: ) -> Union[int, RoomStreamToken, MultiWriterStreamToken]: ...
...
def get_field( def get_field(
self, key: StreamKeyType self, key: StreamKeyType

View file

@ -357,24 +357,21 @@ T4 = TypeVar("T4")
@overload @overload
def gather_results( def gather_results(
deferredList: Tuple[()], consumeErrors: bool = ... deferredList: Tuple[()], consumeErrors: bool = ...
) -> "defer.Deferred[Tuple[()]]": ) -> "defer.Deferred[Tuple[()]]": ...
...
@overload @overload
def gather_results( def gather_results(
deferredList: Tuple["defer.Deferred[T1]"], deferredList: Tuple["defer.Deferred[T1]"],
consumeErrors: bool = ..., consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1]]": ) -> "defer.Deferred[Tuple[T1]]": ...
...
@overload @overload
def gather_results( def gather_results(
deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"], deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"],
consumeErrors: bool = ..., consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1, T2]]": ) -> "defer.Deferred[Tuple[T1, T2]]": ...
...
@overload @overload
@ -383,8 +380,7 @@ def gather_results(
"defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]" "defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]"
], ],
consumeErrors: bool = ..., consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1, T2, T3]]": ) -> "defer.Deferred[Tuple[T1, T2, T3]]": ...
...
@overload @overload
@ -396,8 +392,7 @@ def gather_results(
"defer.Deferred[T4]", "defer.Deferred[T4]",
], ],
consumeErrors: bool = ..., consumeErrors: bool = ...,
) -> "defer.Deferred[Tuple[T1, T2, T3, T4]]": ) -> "defer.Deferred[Tuple[T1, T2, T3, T4]]": ...
...
def gather_results( # type: ignore[misc] def gather_results( # type: ignore[misc]
@ -782,18 +777,15 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
@overload @overload
def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]": def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]": ...
...
@overload @overload
def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]": def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]": ...
...
@overload @overload
def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]: def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]: ...
...
def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]: def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:

View file

@ -152,12 +152,10 @@ class ExpiringCache(Generic[KT, VT]):
return key in self._cache return key in self._cache
@overload @overload
def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]: def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]: ...
...
@overload @overload
def get(self, key: KT, default: T) -> Union[VT, T]: def get(self, key: KT, default: T) -> Union[VT, T]: ...
...
def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]: def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]:
try: try:

View file

@ -580,8 +580,7 @@ class LruCache(Generic[KT, VT]):
callbacks: Collection[Callable[[], None]] = ..., callbacks: Collection[Callable[[], None]] = ...,
update_metrics: bool = ..., update_metrics: bool = ...,
update_last_access: bool = ..., update_last_access: bool = ...,
) -> Optional[VT]: ) -> Optional[VT]: ...
...
@overload @overload
def cache_get( def cache_get(
@ -590,8 +589,7 @@ class LruCache(Generic[KT, VT]):
callbacks: Collection[Callable[[], None]] = ..., callbacks: Collection[Callable[[], None]] = ...,
update_metrics: bool = ..., update_metrics: bool = ...,
update_last_access: bool = ..., update_last_access: bool = ...,
) -> Union[T, VT]: ) -> Union[T, VT]: ...
...
@synchronized @synchronized
def cache_get( def cache_get(
@ -634,16 +632,14 @@ class LruCache(Generic[KT, VT]):
key: tuple, key: tuple,
default: Literal[None] = None, default: Literal[None] = None,
update_metrics: bool = True, update_metrics: bool = True,
) -> Union[None, Iterable[Tuple[KT, VT]]]: ) -> Union[None, Iterable[Tuple[KT, VT]]]: ...
...
@overload @overload
def cache_get_multi( def cache_get_multi(
key: tuple, key: tuple,
default: T, default: T,
update_metrics: bool = True, update_metrics: bool = True,
) -> Union[T, Iterable[Tuple[KT, VT]]]: ) -> Union[T, Iterable[Tuple[KT, VT]]]: ...
...
@synchronized @synchronized
def cache_get_multi( def cache_get_multi(
@ -728,12 +724,10 @@ class LruCache(Generic[KT, VT]):
return value return value
@overload @overload
def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]: def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]: ...
...
@overload @overload
def cache_pop(key: KT, default: T) -> Union[T, VT]: def cache_pop(key: KT, default: T) -> Union[T, VT]: ...
...
@synchronized @synchronized
def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]: def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]:

View file

@ -50,8 +50,7 @@ class _SelfSlice(Sized, Protocol):
returned. returned.
""" """
def __getitem__(self: S, i: slice) -> S: def __getitem__(self: S, i: slice) -> S: ...
...
def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]: def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:

View file

@ -177,9 +177,9 @@ class FederationRateLimiter:
clock=clock, config=config, metrics_name=metrics_name clock=clock, config=config, metrics_name=metrics_name
) )
self.ratelimiters: DefaultDict[ self.ratelimiters: DefaultDict[str, "_PerHostRatelimiter"] = (
str, "_PerHostRatelimiter" collections.defaultdict(new_limiter)
] = collections.defaultdict(new_limiter) )
with _rate_limiter_instances_lock: with _rate_limiter_instances_lock:
_rate_limiter_instances.add(self) _rate_limiter_instances.add(self)

View file

@ -129,9 +129,9 @@ async def filter_events_for_client(
retention_policies: Dict[str, RetentionPolicy] = {} retention_policies: Dict[str, RetentionPolicy] = {}
for room_id in room_ids: for room_id in room_ids:
retention_policies[ retention_policies[room_id] = (
room_id await storage.main.get_retention_policy_for_room(room_id)
] = await storage.main.get_retention_policy_for_room(room_id) )
def allowed(event: EventBase) -> Optional[EventBase]: def allowed(event: EventBase) -> Optional[EventBase]:
return _check_client_allowed_to_see_event( return _check_client_allowed_to_see_event(

View file

@ -495,9 +495,9 @@ class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub.""" """A fake Redis server for pub/sub."""
def __init__(self) -> None: def __init__(self) -> None:
self._subscribers_by_channel: Dict[ self._subscribers_by_channel: Dict[bytes, Set["FakeRedisPubSubProtocol"]] = (
bytes, Set["FakeRedisPubSubProtocol"] defaultdict(set)
] = defaultdict(set) )
def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None: def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
"""A connection has called SUBSCRIBE""" """A connection has called SUBSCRIBE"""

View file

@ -1222,9 +1222,9 @@ class RoomJoinTestCase(RoomBase):
""" """
# Register a dummy callback. Make it allow all room joins for now. # Register a dummy callback. Make it allow all room joins for now.
return_value: Union[ return_value: Union[Literal["NOT_SPAM"], Tuple[Codes, dict], Codes] = (
Literal["NOT_SPAM"], Tuple[Codes, dict], Codes synapse.module_api.NOT_SPAM
] = synapse.module_api.NOT_SPAM )
async def user_may_join_room( async def user_may_join_room(
userid: str, userid: str,
@ -1664,9 +1664,9 @@ class RoomMessagesTestCase(RoomBase):
expected_fields: dict, expected_fields: dict,
) -> None: ) -> None:
class SpamCheck: class SpamCheck:
mock_return_value: Union[ mock_return_value: Union[str, bool, Codes, Tuple[Codes, JsonDict], bool] = (
str, bool, Codes, Tuple[Codes, JsonDict], bool "NOT_SPAM"
] = "NOT_SPAM" )
mock_content: Optional[JsonDict] = None mock_content: Optional[JsonDict] = None
async def check_event_for_spam( async def check_event_for_spam(

View file

@ -87,8 +87,7 @@ class RestHelper:
expect_code: Literal[200] = ..., expect_code: Literal[200] = ...,
extra_content: Optional[Dict] = ..., extra_content: Optional[Dict] = ...,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
) -> str: ) -> str: ...
...
@overload @overload
def create_room_as( def create_room_as(
@ -100,8 +99,7 @@ class RestHelper:
expect_code: int = ..., expect_code: int = ...,
extra_content: Optional[Dict] = ..., extra_content: Optional[Dict] = ...,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
) -> Optional[str]: ) -> Optional[str]: ...
...
def create_room_as( def create_room_as(
self, self,

View file

@ -337,15 +337,15 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion() """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
expires old entries correctly. expires old entries correctly.
""" """
self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion["1"] = (
"1" 100000
] = 100000 )
self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion["2"] = (
"2" 200000
] = 200000 )
self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion["3"] = (
"3" 300000
] = 300000 )
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
# All entries within time frame # All entries within time frame

View file

@ -328,9 +328,11 @@ class MessageSearchTest(HomeserverTestCase):
self.assertEqual( self.assertEqual(
result["count"], result["count"],
1 if expect_to_contain else 0, 1 if expect_to_contain else 0,
f"expected '{query}' to match '{self.PHRASE}'" (
if expect_to_contain f"expected '{query}' to match '{self.PHRASE}'"
else f"'{query}' unexpectedly matched '{self.PHRASE}'", if expect_to_contain
else f"'{query}' unexpectedly matched '{self.PHRASE}'"
),
) )
self.assertEqual( self.assertEqual(
len(result["results"]), len(result["results"]),
@ -346,9 +348,11 @@ class MessageSearchTest(HomeserverTestCase):
self.assertEqual( self.assertEqual(
result["count"], result["count"],
1 if expect_to_contain else 0, 1 if expect_to_contain else 0,
f"expected '{query}' to match '{self.PHRASE}'" (
if expect_to_contain f"expected '{query}' to match '{self.PHRASE}'"
else f"'{query}' unexpectedly matched '{self.PHRASE}'", if expect_to_contain
else f"'{query}' unexpectedly matched '{self.PHRASE}'"
),
) )
self.assertEqual( self.assertEqual(
len(result["results"]), len(result["results"]),

View file

@ -109,8 +109,7 @@ class _TypedFailure(Generic[_ExcType], Protocol):
"""Extension to twisted.Failure, where the 'value' has a certain type.""" """Extension to twisted.Failure, where the 'value' has a certain type."""
@property @property
def value(self) -> _ExcType: def value(self) -> _ExcType: ...
...
def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]: def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:

View file

@ -34,8 +34,7 @@ from tests import unittest
class UnblockFunction(Protocol): class UnblockFunction(Protocol):
def __call__(self, pump_reactor: bool = True) -> None: def __call__(self, pump_reactor: bool = True) -> None: ...
...
class LinearizerTestCase(unittest.TestCase): class LinearizerTestCase(unittest.TestCase):

View file

@ -121,13 +121,11 @@ def setupdb() -> None:
@overload @overload
def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]: def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]: ...
...
@overload @overload
def default_config(name: str, parse: Literal[True]) -> HomeServerConfig: def default_config(name: str, parse: Literal[True]) -> HomeServerConfig: ...
...
def default_config( def default_config(