0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 05:13:50 +01:00

Use inline type hints in handlers/ and rest/. (#10382)

This commit is contained in:
Jonathan de Jong 2021-07-16 19:22:36 +02:00 committed by GitHub
parent 36dc15412d
commit 98aec1cc9d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
43 changed files with 212 additions and 215 deletions

1
changelog.d/10382.misc Normal file
View file

@ -0,0 +1 @@
Convert internal type variable syntax to reflect wider ecosystem use.

View file

@ -38,10 +38,10 @@ class BaseHandler:
""" """
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() # type: synapse.storage.DataStore self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.hs = hs self.hs = hs
@ -55,12 +55,12 @@ class BaseHandler:
# Check whether ratelimiting room admin message redaction is enabled # Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config # by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction: if self.hs.config.rc_admin_redaction:
self.admin_redaction_ratelimiter = Ratelimiter( self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
store=self.store, store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second, rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count, burst_count=self.hs.config.rc_admin_redaction.burst_count,
) # type: Optional[Ratelimiter] )
else: else:
self.admin_redaction_ratelimiter = None self.admin_redaction_ratelimiter = None

View file

@ -139,7 +139,7 @@ class AdminHandler(BaseHandler):
to_key = RoomStreamToken(None, stream_ordering) to_key = RoomStreamToken(None, stream_ordering)
# Events that we've processed in this room # Events that we've processed in this room
written_events = set() # type: Set[str] written_events: Set[str] = set()
# We need to track gaps in the events stream so that we can then # We need to track gaps in the events stream so that we can then
# write out the state at those events. We do this by keeping track # write out the state at those events. We do this by keeping track
@ -152,7 +152,7 @@ class AdminHandler(BaseHandler):
# The reverse mapping to above, i.e. map from unseen event to events # The reverse mapping to above, i.e. map from unseen event to events
# that have the unseen event in their prev_events, i.e. the unseen # that have the unseen event in their prev_events, i.e. the unseen
# events "children". # events "children".
unseen_to_child_events = {} # type: Dict[str, Set[str]] unseen_to_child_events: Dict[str, Set[str]] = {}
# We fetch events in the room the user could see by fetching *all* # We fetch events in the room the user could see by fetching *all*
# events that we have and then filtering, this isn't the most # events that we have and then filtering, this isn't the most

View file

@ -96,7 +96,7 @@ class ApplicationServicesHandler:
self.current_max, limit self.current_max, limit
) )
events_by_room = {} # type: Dict[str, List[EventBase]] events_by_room: Dict[str, List[EventBase]] = {}
for event in events: for event in events:
events_by_room.setdefault(event.room_id, []).append(event) events_by_room.setdefault(event.room_id, []).append(event)
@ -275,7 +275,7 @@ class ApplicationServicesHandler:
async def _handle_presence( async def _handle_presence(
self, service: ApplicationService, users: Collection[Union[str, UserID]] self, service: ApplicationService, users: Collection[Union[str, UserID]]
) -> List[JsonDict]: ) -> List[JsonDict]:
events = [] # type: List[JsonDict] events: List[JsonDict] = []
presence_source = self.event_sources.sources["presence"] presence_source = self.event_sources.sources["presence"]
from_key = await self.store.get_type_stream_id_for_appservice( from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence" service, "presence"
@ -375,7 +375,7 @@ class ApplicationServicesHandler:
self, only_protocol: Optional[str] = None self, only_protocol: Optional[str] = None
) -> Dict[str, JsonDict]: ) -> Dict[str, JsonDict]:
services = self.store.get_app_services() services = self.store.get_app_services()
protocols = {} # type: Dict[str, List[JsonDict]] protocols: Dict[str, List[JsonDict]] = {}
# Collect up all the individual protocol responses out of the ASes # Collect up all the individual protocol responses out of the ASes
for s in services: for s in services:

View file

@ -191,7 +191,7 @@ class AuthHandler(BaseHandler):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs) inst = auth_checker_class(hs)
if inst.is_enabled(): if inst.is_enabled():
@ -296,7 +296,7 @@ class AuthHandler(BaseHandler):
# A mapping of user ID to extra attributes to include in the login # A mapping of user ID to extra attributes to include in the login
# response. # response.
self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes] self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {}
async def validate_user_via_ui_auth( async def validate_user_via_ui_auth(
self, self,
@ -500,7 +500,7 @@ class AuthHandler(BaseHandler):
all the stages in any of the permitted flows. all the stages in any of the permitted flows.
""" """
sid = None # type: Optional[str] sid: Optional[str] = None
authdict = clientdict.pop("auth", {}) authdict = clientdict.pop("auth", {})
if "session" in authdict: if "session" in authdict:
sid = authdict["session"] sid = authdict["session"]
@ -588,9 +588,9 @@ class AuthHandler(BaseHandler):
) )
# check auth type currently being presented # check auth type currently being presented
errordict = {} # type: Dict[str, Any] errordict: Dict[str, Any] = {}
if "type" in authdict: if "type" in authdict:
login_type = authdict["type"] # type: str login_type: str = authdict["type"]
try: try:
result = await self._check_auth_dict(authdict, clientip) result = await self._check_auth_dict(authdict, clientip)
if result: if result:
@ -766,7 +766,7 @@ class AuthHandler(BaseHandler):
LoginType.TERMS: self._get_params_terms, LoginType.TERMS: self._get_params_terms,
} }
params = {} # type: Dict[str, Any] params: Dict[str, Any] = {}
for f in public_flows: for f in public_flows:
for stage in f: for stage in f:
@ -1530,9 +1530,9 @@ class AuthHandler(BaseHandler):
except StoreError: except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
user_id_to_verify = await self.get_session_data( user_id_to_verify: str = await self.get_session_data(
session_id, UIAuthSessionDataConstants.REQUEST_USER_ID session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
) # type: str )
idps = await self.hs.get_sso_handler().get_identity_providers_for_user( idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
user_id_to_verify user_id_to_verify

View file

@ -171,7 +171,7 @@ class CasHandler:
# Iterate through the nodes and pull out the user and any extra attributes. # Iterate through the nodes and pull out the user and any extra attributes.
user = None user = None
attributes = {} # type: Dict[str, List[Optional[str]]] attributes: Dict[str, List[Optional[str]]] = {}
for child in root[0]: for child in root[0]:
if child.tag.endswith("user"): if child.tag.endswith("user"):
user = child.text user = child.text

View file

@ -452,7 +452,7 @@ class DeviceHandler(DeviceWorkerHandler):
user_id user_id
) )
hosts = set() # type: Set[str] hosts: Set[str] = set()
if self.hs.is_mine_id(user_id): if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room) hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name) hosts.discard(self.server_name)
@ -613,20 +613,20 @@ class DeviceListUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_device_list") self._remote_edu_linearizer = Linearizer(name="remote_device_list")
# user_id -> list of updates waiting to be handled. # user_id -> list of updates waiting to be handled.
self._pending_updates = ( self._pending_updates: Dict[
{} str, List[Tuple[str, str, Iterable[str], JsonDict]]
) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]] ] = {}
# Recently seen stream ids. We don't bother keeping these in the DB, # Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious # but they're useful to have them about to reduce the number of spurious
# resyncs. # resyncs.
self._seen_updates = ExpiringCache( self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
cache_name="device_update_edu", cache_name="device_update_edu",
clock=self.clock, clock=self.clock,
max_len=10000, max_len=10000,
expiry_ms=30 * 60 * 1000, expiry_ms=30 * 60 * 1000,
iterable=True, iterable=True,
) # type: ExpiringCache[str, Set[str]] )
# Attempt to resync out of sync device lists every 30s. # Attempt to resync out of sync device lists every 30s.
self._resync_retry_in_progress = False self._resync_retry_in_progress = False
@ -755,7 +755,7 @@ class DeviceListUpdater:
"""Given a list of updates for a user figure out if we need to do a full """Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta. resync, or whether we have enough data that we can just apply the delta.
""" """
seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str] seen_updates: Set[str] = self._seen_updates.get(user_id, set())
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id) extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)

View file

@ -203,7 +203,7 @@ class DeviceMessageHandler:
log_kv({"number_of_to_device_messages": len(messages)}) log_kv({"number_of_to_device_messages": len(messages)})
set_tag("sender", sender_user_id) set_tag("sender", sender_user_id)
local_messages = {} local_messages = {}
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items(): for user_id, by_device in messages.items():
# Ratelimit local cross-user key requests by the sending device. # Ratelimit local cross-user key requests by the sending device.
if ( if (

View file

@ -237,9 +237,9 @@ class DirectoryHandler(BaseHandler):
async def get_association(self, room_alias: RoomAlias) -> JsonDict: async def get_association(self, room_alias: RoomAlias) -> JsonDict:
room_id = None room_id = None
if self.hs.is_mine(room_alias): if self.hs.is_mine(room_alias):
result = await self.get_association_from_room_alias( result: Optional[
room_alias RoomAliasMapping
) # type: Optional[RoomAliasMapping] ] = await self.get_association_from_room_alias(room_alias)
if result: if result:
room_id = result.room_id room_id = result.room_id

View file

@ -115,9 +115,9 @@ class E2eKeysHandler:
the number of in-flight queries at a time. the number of in-flight queries at a time.
""" """
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)): with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query = query_body.get( device_keys_query: Dict[str, Iterable[str]] = query_body.get(
"device_keys", {} "device_keys", {}
) # type: Dict[str, Iterable[str]] )
# separate users by domain. # separate users by domain.
# make a map from domain to user_id to device_ids # make a map from domain to user_id to device_ids
@ -136,7 +136,7 @@ class E2eKeysHandler:
# First get local devices. # First get local devices.
# A map of destination -> failure response. # A map of destination -> failure response.
failures = {} # type: Dict[str, JsonDict] failures: Dict[str, JsonDict] = {}
results = {} results = {}
if local_query: if local_query:
local_result = await self.query_local_devices(local_query) local_result = await self.query_local_devices(local_query)
@ -151,11 +151,9 @@ class E2eKeysHandler:
# Now attempt to get any remote devices from our local cache. # Now attempt to get any remote devices from our local cache.
# A map of destination -> user ID -> device IDs. # A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache = ( remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
{}
) # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries: if remote_queries:
query_list = [] # type: List[Tuple[str, Optional[str]]] query_list: List[Tuple[str, Optional[str]]] = []
for user_id, device_ids in remote_queries.items(): for user_id, device_ids in remote_queries.items():
if device_ids: if device_ids:
query_list.extend( query_list.extend(
@ -362,9 +360,9 @@ class E2eKeysHandler:
A map from user_id -> device_id -> device details A map from user_id -> device_id -> device details
""" """
set_tag("local_query", query) set_tag("local_query", query)
local_query = [] # type: List[Tuple[str, Optional[str]]] local_query: List[Tuple[str, Optional[str]]] = []
result_dict = {} # type: Dict[str, Dict[str, dict]] result_dict: Dict[str, Dict[str, dict]] = {}
for user_id, device_ids in query.items(): for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids # we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)): if not self.is_mine(UserID.from_string(user_id)):
@ -402,9 +400,9 @@ class E2eKeysHandler:
self, query_body: Dict[str, Dict[str, Optional[List[str]]]] self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
) -> JsonDict: ) -> JsonDict:
"""Handle a device key query from a federated server""" """Handle a device key query from a federated server"""
device_keys_query = query_body.get( device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
"device_keys", {} "device_keys", {}
) # type: Dict[str, Optional[List[str]]] )
res = await self.query_local_devices(device_keys_query) res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res} ret = {"device_keys": res}
@ -421,8 +419,8 @@ class E2eKeysHandler:
async def claim_one_time_keys( async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
) -> JsonDict: ) -> JsonDict:
local_query = [] # type: List[Tuple[str, str, str]] local_query: List[Tuple[str, str, str]] = []
remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]] remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
for user_id, one_time_keys in query.get("one_time_keys", {}).items(): for user_id, one_time_keys in query.get("one_time_keys", {}).items():
# we use UserID.from_string to catch invalid user ids # we use UserID.from_string to catch invalid user ids
@ -439,8 +437,8 @@ class E2eKeysHandler:
results = await self.store.claim_e2e_one_time_keys(local_query) results = await self.store.claim_e2e_one_time_keys(local_query)
# A map of user ID -> device ID -> key ID -> key. # A map of user ID -> device ID -> key ID -> key.
json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
failures = {} # type: Dict[str, JsonDict] failures: Dict[str, JsonDict] = {}
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
for device_id, keys in device_keys.items(): for device_id, keys in device_keys.items():
for key_id, json_str in keys.items(): for key_id, json_str in keys.items():
@ -768,8 +766,8 @@ class E2eKeysHandler:
Raises: Raises:
SynapseError: if the input is malformed SynapseError: if the input is malformed
""" """
signature_list = [] # type: List[SignatureListItem] signature_list: List["SignatureListItem"] = []
failures = {} # type: Dict[str, Dict[str, JsonDict]] failures: Dict[str, Dict[str, JsonDict]] = {}
if not signatures: if not signatures:
return signature_list, failures return signature_list, failures
@ -930,8 +928,8 @@ class E2eKeysHandler:
Raises: Raises:
SynapseError: if the input is malformed SynapseError: if the input is malformed
""" """
signature_list = [] # type: List[SignatureListItem] signature_list: List["SignatureListItem"] = []
failures = {} # type: Dict[str, Dict[str, JsonDict]] failures: Dict[str, Dict[str, JsonDict]] = {}
if not signatures: if not signatures:
return signature_list, failures return signature_list, failures
@ -1300,7 +1298,7 @@ class SigningKeyEduUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_signing_key") self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
# user_id -> list of updates waiting to be handled. # user_id -> list of updates waiting to be handled.
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]] self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {}
async def incoming_signing_key_update( async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict self, origin: str, edu_content: JsonDict
@ -1349,7 +1347,7 @@ class SigningKeyEduUpdater:
# This can happen since we batch updates # This can happen since we batch updates
return return
device_ids = [] # type: List[str] device_ids: List[str] = []
logger.info("pending updates: %r", pending_updates) logger.info("pending updates: %r", pending_updates)

View file

@ -93,7 +93,7 @@ class EventStreamHandler(BaseHandler):
# When the user joins a new room, or another user joins a currently # When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users. # joined room, we need to send down presence for those users.
to_add = [] # type: List[JsonDict] to_add: List[JsonDict] = []
for event in events: for event in events:
if not isinstance(event, EventBase): if not isinstance(event, EventBase):
continue continue
@ -103,9 +103,9 @@ class EventStreamHandler(BaseHandler):
# Send down presence. # Send down presence.
if event.state_key == auth_user_id: if event.state_key == auth_user_id:
# Send down presence for everyone in the room. # Send down presence for everyone in the room.
users = await self.store.get_users_in_room( users: Iterable[str] = await self.store.get_users_in_room(
event.room_id event.room_id
) # type: Iterable[str] )
else: else:
users = [event.state_key] users = [event.state_key]

View file

@ -181,7 +181,7 @@ class FederationHandler(BaseHandler):
# When joining a room we need to queue any events for that room up. # When joining a room we need to queue any events for that room up.
# For each room, a list of (pdu, origin) tuples. # For each room, a list of (pdu, origin) tuples.
self.room_queues = {} # type: Dict[str, List[Tuple[EventBase, str]]] self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
self._room_pdu_linearizer = Linearizer("fed_room_pdu") self._room_pdu_linearizer = Linearizer("fed_room_pdu")
self._room_backfill = Linearizer("room_backfill") self._room_backfill = Linearizer("room_backfill")
@ -368,7 +368,7 @@ class FederationHandler(BaseHandler):
ours = await self.state_store.get_state_groups_ids(room_id, seen) ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id # state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(ours.values()) # type: List[StateMap[str]] state_maps: List[StateMap[str]] = list(ours.values())
# we don't need this any more, let's delete it. # we don't need this any more, let's delete it.
del ours del ours
@ -845,7 +845,7 @@ class FederationHandler(BaseHandler):
# exact key to expect. Otherwise check it matches any key we # exact key to expect. Otherwise check it matches any key we
# have for that device. # have for that device.
current_keys = [] # type: Container[str] current_keys: Container[str] = []
if device: if device:
keys = device.get("keys", {}).get("keys", {}) keys = device.get("keys", {}).get("keys", {})
@ -1185,7 +1185,7 @@ class FederationHandler(BaseHandler):
if e_type == EventTypes.Member and event.membership == Membership.JOIN if e_type == EventTypes.Member and event.membership == Membership.JOIN
] ]
joined_domains = {} # type: Dict[str, int] joined_domains: Dict[str, int] = {}
for u, d in joined_users: for u, d in joined_users:
try: try:
dom = get_domain_from_id(u) dom = get_domain_from_id(u)
@ -1314,7 +1314,7 @@ class FederationHandler(BaseHandler):
room_version = await self.store.get_room_version(room_id) room_version = await self.store.get_room_version(room_id)
event_map = {} # type: Dict[str, EventBase] event_map: Dict[str, EventBase] = {}
async def get_event(event_id: str): async def get_event(event_id: str):
with nested_logging_context(event_id): with nested_logging_context(event_id):
@ -1596,7 +1596,7 @@ class FederationHandler(BaseHandler):
# Ask the remote server to create a valid knock event for us. Once received, # Ask the remote server to create a valid knock event for us. Once received,
# we sign the event # we sign the event
params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]] params: Dict[str, Iterable[str]] = {"ver": supported_room_versions}
origin, event, event_format_version = await self._make_and_verify_event( origin, event, event_format_version = await self._make_and_verify_event(
target_hosts, room_id, knockee, Membership.KNOCK, content, params=params target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
) )
@ -2453,14 +2453,14 @@ class FederationHandler(BaseHandler):
state_sets_d = await self.state_store.get_state_groups( state_sets_d = await self.state_store.get_state_groups(
event.room_id, extrem_ids event.room_id, extrem_ids
) )
state_sets = list(state_sets_d.values()) # type: List[Iterable[EventBase]] state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
state_sets.append(state) state_sets.append(state)
current_states = await self.state_handler.resolve_events( current_states = await self.state_handler.resolve_events(
room_version, state_sets, event room_version, state_sets, event
) )
current_state_ids = { current_state_ids: StateMap[str] = {
k: e.event_id for k, e in current_states.items() k: e.event_id for k, e in current_states.items()
} # type: StateMap[str] }
else: else:
current_state_ids = await self.state_handler.get_current_state_ids( current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids event.room_id, latest_event_ids=extrem_ids
@ -2817,7 +2817,7 @@ class FederationHandler(BaseHandler):
""" """
# exclude the state key of the new event from the current_state in the context. # exclude the state key of the new event from the current_state in the context.
if event.is_state(): if event.is_state():
event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]] event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
else: else:
event_key = None event_key = None
state_updates = { state_updates = {
@ -3156,7 +3156,7 @@ class FederationHandler(BaseHandler):
logger.debug("Checking auth on event %r", event.content) logger.debug("Checking auth on event %r", event.content)
last_exception = None # type: Optional[Exception] last_exception: Optional[Exception] = None
# for each public key in the 3pid invite event # for each public key in the 3pid invite event
for public_key_object in event_auth.get_public_keys(invite_event): for public_key_object in event_auth.get_public_keys(invite_event):

View file

@ -214,7 +214,7 @@ class GroupsLocalWorkerHandler:
async def bulk_get_publicised_groups( async def bulk_get_publicised_groups(
self, user_ids: Iterable[str], proxy: bool = True self, user_ids: Iterable[str], proxy: bool = True
) -> JsonDict: ) -> JsonDict:
destinations = {} # type: Dict[str, Set[str]] destinations: Dict[str, Set[str]] = {}
local_users = set() local_users = set()
for user_id in user_ids: for user_id in user_ids:
@ -227,7 +227,7 @@ class GroupsLocalWorkerHandler:
raise SynapseError(400, "Some user_ids are not local") raise SynapseError(400, "Some user_ids are not local")
results = {} results = {}
failed_results = [] # type: List[str] failed_results: List[str] = []
for destination, dest_user_ids in destinations.items(): for destination, dest_user_ids in destinations.items():
try: try:
r = await self.transport_client.bulk_get_publicised_groups( r = await self.transport_client.bulk_get_publicised_groups(

View file

@ -46,9 +46,17 @@ class InitialSyncHandler(BaseHandler):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.validator = EventValidator() self.validator = EventValidator()
self.snapshot_cache = ResponseCache( self.snapshot_cache: ResponseCache[
hs.get_clock(), "initial_sync_cache" Tuple[
) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]] str,
Optional[StreamToken],
Optional[StreamToken],
str,
Optional[int],
bool,
bool,
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state self.state_store = self.storage.state

View file

@ -81,7 +81,7 @@ class MessageHandler:
# The scheduled call to self._expire_event. None if no call is currently # The scheduled call to self._expire_event. None if no call is currently
# scheduled. # scheduled.
self._scheduled_expiry = None # type: Optional[IDelayedCall] self._scheduled_expiry: Optional[IDelayedCall] = None
if not hs.config.worker_app: if not hs.config.worker_app:
run_as_background_process( run_as_background_process(
@ -196,9 +196,7 @@ class MessageHandler:
room_state_events = await self.state_store.get_state_for_events( room_state_events = await self.state_store.get_state_for_events(
[event.event_id], state_filter=state_filter [event.event_id], state_filter=state_filter
) )
room_state = room_state_events[ room_state: Mapping[Any, EventBase] = room_state_events[event.event_id]
event.event_id
] # type: Mapping[Any, EventBase]
else: else:
raise AuthError( raise AuthError(
403, 403,
@ -421,9 +419,9 @@ class EventCreationHandler:
self.action_generator = hs.get_action_generator() self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.third_party_event_rules = ( self.third_party_event_rules: "ThirdPartyEventRules" = (
self.hs.get_third_party_event_rules() self.hs.get_third_party_event_rules()
) # type: ThirdPartyEventRules )
self._block_events_without_consent_error = ( self._block_events_without_consent_error = (
self.config.block_events_without_consent_error self.config.block_events_without_consent_error
@ -440,7 +438,7 @@ class EventCreationHandler:
# #
# map from room id to time-of-last-attempt. # map from room id to time-of-last-attempt.
# #
self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int] self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {}
# The number of forward extremeities before a dummy event is sent. # The number of forward extremeities before a dummy event is sent.
self._dummy_events_threshold = hs.config.dummy_events_threshold self._dummy_events_threshold = hs.config.dummy_events_threshold
@ -465,9 +463,7 @@ class EventCreationHandler:
# Stores the state groups we've recently added to the joined hosts # Stores the state groups we've recently added to the joined hosts
# external cache. Note that the timeout must be significantly less than # external cache. Note that the timeout must be significantly less than
# the TTL on the external cache. # the TTL on the external cache.
self._external_cache_joined_hosts_updates = ( self._external_cache_joined_hosts_updates: Optional[ExpiringCache] = None
None
) # type: Optional[ExpiringCache]
if self._external_cache.is_enabled(): if self._external_cache.is_enabled():
self._external_cache_joined_hosts_updates = ExpiringCache( self._external_cache_joined_hosts_updates = ExpiringCache(
"_external_cache_joined_hosts_updates", "_external_cache_joined_hosts_updates",
@ -1299,7 +1295,7 @@ class EventCreationHandler:
# Validate a newly added alias or newly added alt_aliases. # Validate a newly added alias or newly added alt_aliases.
original_alias = None original_alias = None
original_alt_aliases = [] # type: List[str] original_alt_aliases: List[str] = []
original_event_id = event.unsigned.get("replaces_state") original_event_id = event.unsigned.get("replaces_state")
if original_event_id: if original_event_id:

View file

@ -105,9 +105,9 @@ class OidcHandler:
assert provider_confs assert provider_confs
self._token_generator = OidcSessionTokenGenerator(hs) self._token_generator = OidcSessionTokenGenerator(hs)
self._providers = { self._providers: Dict[str, "OidcProvider"] = {
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
} # type: Dict[str, OidcProvider] }
async def load_metadata(self) -> None: async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint. """Validate the config and load the metadata from the remote endpoint.
@ -178,7 +178,7 @@ class OidcHandler:
# are two. # are two.
for cookie_name, _ in _SESSION_COOKIES: for cookie_name, _ in _SESSION_COOKIES:
session = request.getCookie(cookie_name) # type: Optional[bytes] session: Optional[bytes] = request.getCookie(cookie_name)
if session is not None: if session is not None:
break break
else: else:
@ -277,7 +277,7 @@ class OidcProvider:
self._token_generator = token_generator self._token_generator = token_generator
self._config = provider self._config = provider
self._callback_url = hs.config.oidc_callback_url # type: str self._callback_url: str = hs.config.oidc_callback_url
# Calculate the prefix for OIDC callback paths based on the public_baseurl. # Calculate the prefix for OIDC callback paths based on the public_baseurl.
# We'll insert this into the Path= parameter of any session cookies we set. # We'll insert this into the Path= parameter of any session cookies we set.
@ -290,7 +290,7 @@ class OidcProvider:
self._scopes = provider.scopes self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method self._user_profile_method = provider.user_profile_method
client_secret = None # type: Union[None, str, JwtClientSecret] client_secret: Optional[Union[str, JwtClientSecret]] = None
if provider.client_secret: if provider.client_secret:
client_secret = provider.client_secret client_secret = provider.client_secret
elif provider.client_secret_jwt_key: elif provider.client_secret_jwt_key:
@ -305,7 +305,7 @@ class OidcProvider:
provider.client_id, provider.client_id,
client_secret, client_secret,
provider.client_auth_method, provider.client_auth_method,
) # type: ClientAuth )
self._client_auth_method = provider.client_auth_method self._client_auth_method = provider.client_auth_method
# cache of metadata for the identity provider (endpoint uris, mostly). This is # cache of metadata for the identity provider (endpoint uris, mostly). This is
@ -324,7 +324,7 @@ class OidcProvider:
self._allow_existing_users = provider.allow_existing_users self._allow_existing_users = provider.allow_existing_users
self._http_client = hs.get_proxied_http_client() self._http_client = hs.get_proxied_http_client()
self._server_name = hs.config.server_name # type: str self._server_name: str = hs.config.server_name
# identifier for the external_ids table # identifier for the external_ids table
self.idp_id = provider.idp_id self.idp_id = provider.idp_id
@ -1381,7 +1381,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if display_name == "": if display_name == "":
display_name = None display_name = None
emails = [] # type: List[str] emails: List[str] = []
email = render_template_field(self._config.email_template) email = render_template_field(self._config.email_template)
if email: if email:
emails.append(email) emails.append(email)
@ -1391,7 +1391,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
) )
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str] extras: Dict[str, str] = {}
for key, template in self._config.extra_attributes.items(): for key, template in self._config.extra_attributes.items():
try: try:
extras[key] = template.render(user=userinfo).strip() extras[key] = template.render(user=userinfo).strip()

View file

@ -81,9 +81,9 @@ class PaginationHandler:
self._server_name = hs.hostname self._server_name = hs.hostname
self.pagination_lock = ReadWriteLock() self.pagination_lock = ReadWriteLock()
self._purges_in_progress_by_room = set() # type: Set[str] self._purges_in_progress_by_room: Set[str] = set()
# map from purge id to PurgeStatus # map from purge id to PurgeStatus
self._purges_by_id = {} # type: Dict[str, PurgeStatus] self._purges_by_id: Dict[str, PurgeStatus] = {}
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime

View file

@ -378,14 +378,14 @@ class WorkerPresenceHandler(BasePresenceHandler):
# The number of ongoing syncs on this process, by user id. # The number of ongoing syncs on this process, by user id.
# Empty if _presence_enabled is false. # Empty if _presence_enabled is false.
self._user_to_num_current_syncs = {} # type: Dict[str, int] self._user_to_num_current_syncs: Dict[str, int] = {}
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.instance_id = hs.get_instance_id() self.instance_id = hs.get_instance_id()
# user_id -> last_sync_ms. Lists the users that have stopped syncing but # user_id -> last_sync_ms. Lists the users that have stopped syncing but
# we haven't notified the presence writer of that yet # we haven't notified the presence writer of that yet
self.users_going_offline = {} # type: Dict[str, int] self.users_going_offline: Dict[str, int] = {}
self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
self._set_state_client = ReplicationPresenceSetState.make_client(hs) self._set_state_client = ReplicationPresenceSetState.make_client(hs)
@ -650,7 +650,7 @@ class PresenceHandler(BasePresenceHandler):
# Set of users who have presence in the `user_to_current_state` that # Set of users who have presence in the `user_to_current_state` that
# have not yet been persisted # have not yet been persisted
self.unpersisted_users_changes = set() # type: Set[str] self.unpersisted_users_changes: Set[str] = set()
hs.get_reactor().addSystemEventTrigger( hs.get_reactor().addSystemEventTrigger(
"before", "before",
@ -664,7 +664,7 @@ class PresenceHandler(BasePresenceHandler):
# Keeps track of the number of *ongoing* syncs on this process. While # Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline. # this is non zero a user will never go offline.
self.user_to_num_current_syncs = {} # type: Dict[str, int] self.user_to_num_current_syncs: Dict[str, int] = {}
# Keeps track of the number of *ongoing* syncs on other processes. # Keeps track of the number of *ongoing* syncs on other processes.
# While any sync is ongoing on another process the user will never # While any sync is ongoing on another process the user will never
@ -674,8 +674,8 @@ class PresenceHandler(BasePresenceHandler):
# we assume that all the sync requests on that process have stopped. # we assume that all the sync requests on that process have stopped.
# Stored as a dict from process_id to set of user_id, and a dict of # Stored as a dict from process_id to set of user_id, and a dict of
# process_id to millisecond timestamp last updated. # process_id to millisecond timestamp last updated.
self.external_process_to_current_syncs = {} # type: Dict[str, Set[str]] self.external_process_to_current_syncs: Dict[str, Set[str]] = {}
self.external_process_last_updated_ms = {} # type: Dict[str, int] self.external_process_last_updated_ms: Dict[str, int] = {}
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
@ -1581,9 +1581,7 @@ class PresenceEventSource:
# The set of users that we're interested in and that have had a presence update. # The set of users that we're interested in and that have had a presence update.
# We'll actually pull the presence updates for these users at the end. # We'll actually pull the presence updates for these users at the end.
interested_and_updated_users = ( interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
set()
) # type: Union[Set[str], FrozenSet[str]]
if from_key: if from_key:
# First get all users that have had a presence update # First get all users that have had a presence update
@ -1950,8 +1948,8 @@ async def get_interested_parties(
A 2-tuple of `(room_ids_to_states, users_to_states)`, A 2-tuple of `(room_ids_to_states, users_to_states)`,
with each item being a dict of `entity_name` -> `[UserPresenceState]` with each item being a dict of `entity_name` -> `[UserPresenceState]`
""" """
room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]] room_ids_to_states: Dict[str, List[UserPresenceState]] = {}
users_to_states = {} # type: Dict[str, List[UserPresenceState]] users_to_states: Dict[str, List[UserPresenceState]] = {}
for state in states: for state in states:
room_ids = await store.get_rooms_for_user(state.user_id) room_ids = await store.get_rooms_for_user(state.user_id)
for room_id in room_ids: for room_id in room_ids:
@ -2063,12 +2061,12 @@ class PresenceFederationQueue:
# stream_id, destinations, user_ids)`. We don't store the full states # stream_id, destinations, user_ids)`. We don't store the full states
# for efficiency, and remote workers will already have the full states # for efficiency, and remote workers will already have the full states
# cached. # cached.
self._queue = [] # type: List[Tuple[int, int, Collection[str], Set[str]]] self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = []
self._next_id = 1 self._next_id = 1
# Map from instance name to current token # Map from instance name to current token
self._current_tokens = {} # type: Dict[str, int] self._current_tokens: Dict[str, int] = {}
if self._queue_presence_updates: if self._queue_presence_updates:
self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS) self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
@ -2168,7 +2166,7 @@ class PresenceFederationQueue:
# handle the case where `from_token` stream ID has already been dropped. # handle the case where `from_token` stream ID has already been dropped.
start_idx = max(from_token + 1 - self._next_id, -len(self._queue)) start_idx = max(from_token + 1 - self._next_id, -len(self._queue))
to_send = [] # type: List[Tuple[int, Tuple[str, str]]] to_send: List[Tuple[int, Tuple[str, str]]] = []
limited = False limited = False
new_id = upto_token new_id = upto_token
for _, stream_id, destinations, user_ids in self._queue[start_idx:]: for _, stream_id, destinations, user_ids in self._queue[start_idx:]:
@ -2216,7 +2214,7 @@ class PresenceFederationQueue:
if not self._federation: if not self._federation:
return return
hosts_to_users = {} # type: Dict[str, Set[str]] hosts_to_users: Dict[str, Set[str]] = {}
for row in rows: for row in rows:
hosts_to_users.setdefault(row.destination, set()).add(row.user_id) hosts_to_users.setdefault(row.destination, set()).add(row.user_id)

View file

@ -197,7 +197,7 @@ class ProfileHandler(BaseHandler):
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
) )
displayname_to_set = new_displayname # type: Optional[str] displayname_to_set: Optional[str] = new_displayname
if new_displayname == "": if new_displayname == "":
displayname_to_set = None displayname_to_set = None
@ -286,7 +286,7 @@ class ProfileHandler(BaseHandler):
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
) )
avatar_url_to_set = new_avatar_url # type: Optional[str] avatar_url_to_set: Optional[str] = new_avatar_url
if new_avatar_url == "": if new_avatar_url == "":
avatar_url_to_set = None avatar_url_to_set = None

View file

@ -98,8 +98,8 @@ class ReceiptsHandler(BaseHandler):
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier.""" """Takes a list of receipts, stores them and informs the notifier."""
min_batch_id = None # type: Optional[int] min_batch_id: Optional[int] = None
max_batch_id = None # type: Optional[int] max_batch_id: Optional[int] = None
for receipt in receipts: for receipt in receipts:
res = await self.store.insert_receipt( res = await self.store.insert_receipt(

View file

@ -87,7 +87,7 @@ class RoomCreationHandler(BaseHandler):
self.config = hs.config self.config = hs.config
# Room state based off defined presets # Room state based off defined presets
self._presets_dict = { self._presets_dict: Dict[str, Dict[str, Any]] = {
RoomCreationPreset.PRIVATE_CHAT: { RoomCreationPreset.PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE, "join_rules": JoinRules.INVITE,
"history_visibility": HistoryVisibility.SHARED, "history_visibility": HistoryVisibility.SHARED,
@ -109,7 +109,7 @@ class RoomCreationHandler(BaseHandler):
"guest_can_join": False, "guest_can_join": False,
"power_level_content_override": {}, "power_level_content_override": {},
}, },
} # type: Dict[str, Dict[str, Any]] }
# Modify presets to selectively enable encryption by default per homeserver config # Modify presets to selectively enable encryption by default per homeserver config
for preset_name, preset_config in self._presets_dict.items(): for preset_name, preset_config in self._presets_dict.items():
@ -127,9 +127,9 @@ class RoomCreationHandler(BaseHandler):
# If a user tries to update the same room multiple times in quick # If a user tries to update the same room multiple times in quick
# succession, only process the first attempt and return its result to # succession, only process the first attempt and return its result to
# subsequent requests # subsequent requests
self._upgrade_response_cache = ResponseCache( self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
) # type: ResponseCache[Tuple[str, str]] )
self._server_notices_mxid = hs.config.server_notices_mxid self._server_notices_mxid = hs.config.server_notices_mxid
self.third_party_event_rules = hs.get_third_party_event_rules() self.third_party_event_rules = hs.get_third_party_event_rules()
@ -377,10 +377,10 @@ class RoomCreationHandler(BaseHandler):
if not await self.spam_checker.user_may_create_room(user_id): if not await self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms") raise SynapseError(403, "You are not permitted to create rooms")
creation_content = { creation_content: JsonDict = {
"room_version": new_room_version.identifier, "room_version": new_room_version.identifier,
"predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id}, "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
} # type: JsonDict }
# Check if old room was non-federatable # Check if old room was non-federatable
@ -936,7 +936,7 @@ class RoomCreationHandler(BaseHandler):
etype=EventTypes.PowerLevels, content=pl_content etype=EventTypes.PowerLevels, content=pl_content
) )
else: else:
power_level_content = { power_level_content: JsonDict = {
"users": {creator_id: 100}, "users": {creator_id: 100},
"users_default": 0, "users_default": 0,
"events": { "events": {
@ -955,7 +955,7 @@ class RoomCreationHandler(BaseHandler):
"kick": 50, "kick": 50,
"redact": 50, "redact": 50,
"invite": 50, "invite": 50,
} # type: JsonDict }
if config["original_invitees_have_ops"]: if config["original_invitees_have_ops"]:
for invitee in invite_list: for invitee in invite_list:

View file

@ -47,12 +47,12 @@ class RoomListHandler(BaseHandler):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search self.enable_room_list_search = hs.config.enable_room_list_search
self.response_cache = ResponseCache( self.response_cache: ResponseCache[
hs.get_clock(), "room_list" Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]
) # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]] ] = ResponseCache(hs.get_clock(), "room_list")
self.remote_response_cache = ResponseCache( self.remote_response_cache: ResponseCache[
hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000 Tuple[str, Optional[int], Optional[str], bool, Optional[str]]
) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]] ] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000)
async def get_local_public_room_list( async def get_local_public_room_list(
self, self,
@ -139,10 +139,10 @@ class RoomListHandler(BaseHandler):
if since_token: if since_token:
batch_token = RoomListNextBatch.from_token(since_token) batch_token = RoomListNextBatch.from_token(since_token)
bounds = ( bounds: Optional[Tuple[int, str]] = (
batch_token.last_joined_members, batch_token.last_joined_members,
batch_token.last_room_id, batch_token.last_room_id,
) # type: Optional[Tuple[int, str]] )
forwards = batch_token.direction_is_forward forwards = batch_token.direction_is_forward
has_batch_token = True has_batch_token = True
else: else:
@ -182,7 +182,7 @@ class RoomListHandler(BaseHandler):
results = [build_room_entry(r) for r in results] results = [build_room_entry(r) for r in results]
response = {} # type: JsonDict response: JsonDict = {}
num_results = len(results) num_results = len(results)
if limit is not None: if limit is not None:
more_to_come = num_results == probing_limit more_to_come = num_results == probing_limit

View file

@ -83,7 +83,7 @@ class SamlHandler(BaseHandler):
self.unstable_idp_brand = None self.unstable_idp_brand = None
# a map from saml session id to Saml2SessionData object # a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {}
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self) self._sso_handler.register_identity_provider(self)
@ -386,10 +386,10 @@ def dot_replace_for_mxid(username: str) -> str:
return username return username
MXID_MAPPER_MAP = { MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = {
"hexencode": map_username_to_mxid_localpart, "hexencode": map_username_to_mxid_localpart,
"dotreplace": dot_replace_for_mxid, "dotreplace": dot_replace_for_mxid,
} # type: Dict[str, Callable[[str], str]] }
@attr.s @attr.s

View file

@ -192,7 +192,7 @@ class SearchHandler(BaseHandler):
# If doing a subset of all rooms seearch, check if any of the rooms # If doing a subset of all rooms seearch, check if any of the rooms
# are from an upgraded room, and search their contents as well # are from an upgraded room, and search their contents as well
if search_filter.rooms: if search_filter.rooms:
historical_room_ids = [] # type: List[str] historical_room_ids: List[str] = []
for room_id in search_filter.rooms: for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist # Add any previous rooms to the search if they exist
ids = await self.get_old_rooms_from_upgraded_room(room_id) ids = await self.get_old_rooms_from_upgraded_room(room_id)
@ -216,9 +216,9 @@ class SearchHandler(BaseHandler):
rank_map = {} # event_id -> rank of event rank_map = {} # event_id -> rank of event
allowed_events = [] allowed_events = []
# Holds result of grouping by room, if applicable # Holds result of grouping by room, if applicable
room_groups = {} # type: Dict[str, JsonDict] room_groups: Dict[str, JsonDict] = {}
# Holds result of grouping by sender, if applicable # Holds result of grouping by sender, if applicable
sender_group = {} # type: Dict[str, JsonDict] sender_group: Dict[str, JsonDict] = {}
# Holds the next_batch for the entire result set if one of those exists # Holds the next_batch for the entire result set if one of those exists
global_next_batch = None global_next_batch = None
@ -262,7 +262,7 @@ class SearchHandler(BaseHandler):
s["results"].append(e.event_id) s["results"].append(e.event_id)
elif order_by == "recent": elif order_by == "recent":
room_events = [] # type: List[EventBase] room_events: List[EventBase] = []
i = 0 i = 0
pagination_token = batch_token pagination_token = batch_token

View file

@ -90,14 +90,14 @@ class SpaceSummaryHandler:
room_queue = deque((_RoomQueueEntry(room_id, ()),)) room_queue = deque((_RoomQueueEntry(room_id, ()),))
# rooms we have already processed # rooms we have already processed
processed_rooms = set() # type: Set[str] processed_rooms: Set[str] = set()
# events we have already processed. We don't necessarily have their event ids, # events we have already processed. We don't necessarily have their event ids,
# so instead we key on (room id, state key) # so instead we key on (room id, state key)
processed_events = set() # type: Set[Tuple[str, str]] processed_events: Set[Tuple[str, str]] = set()
rooms_result = [] # type: List[JsonDict] rooms_result: List[JsonDict] = []
events_result = [] # type: List[JsonDict] events_result: List[JsonDict] = []
while room_queue and len(rooms_result) < MAX_ROOMS: while room_queue and len(rooms_result) < MAX_ROOMS:
queue_entry = room_queue.popleft() queue_entry = room_queue.popleft()
@ -272,10 +272,10 @@ class SpaceSummaryHandler:
# the set of rooms that we should not walk further. Initialise it with the # the set of rooms that we should not walk further. Initialise it with the
# excluded-rooms list; we will add other rooms as we process them so that # excluded-rooms list; we will add other rooms as we process them so that
# we do not loop. # we do not loop.
processed_rooms = set(exclude_rooms) # type: Set[str] processed_rooms: Set[str] = set(exclude_rooms)
rooms_result = [] # type: List[JsonDict] rooms_result: List[JsonDict] = []
events_result = [] # type: List[JsonDict] events_result: List[JsonDict] = []
while room_queue and len(rooms_result) < MAX_ROOMS: while room_queue and len(rooms_result) < MAX_ROOMS:
room_id = room_queue.popleft() room_id = room_queue.popleft()
@ -353,7 +353,7 @@ class SpaceSummaryHandler:
max_children = MAX_ROOMS_PER_SPACE max_children = MAX_ROOMS_PER_SPACE
now = self._clock.time_msec() now = self._clock.time_msec()
events_result = [] # type: List[JsonDict] events_result: List[JsonDict] = []
for edge_event in itertools.islice(child_events, max_children): for edge_event in itertools.islice(child_events, max_children):
events_result.append( events_result.append(
await self._event_serializer.serialize_event( await self._event_serializer.serialize_event(

View file

@ -202,10 +202,10 @@ class SsoHandler:
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock()) self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
# a map from session id to session data # a map from session id to session data
self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession] self._username_mapping_sessions: Dict[str, UsernameMappingSession] = {}
# map from idp_id to SsoIdentityProvider # map from idp_id to SsoIdentityProvider
self._identity_providers = {} # type: Dict[str, SsoIdentityProvider] self._identity_providers: Dict[str, SsoIdentityProvider] = {}
self._consent_at_registration = hs.config.consent.user_consent_at_registration self._consent_at_registration = hs.config.consent.user_consent_at_registration
@ -296,7 +296,7 @@ class SsoHandler:
) )
# if the client chose an IdP, use that # if the client chose an IdP, use that
idp = None # type: Optional[SsoIdentityProvider] idp: Optional[SsoIdentityProvider] = None
if idp_id: if idp_id:
idp = self._identity_providers.get(idp_id) idp = self._identity_providers.get(idp_id)
if not idp: if not idp:
@ -669,9 +669,9 @@ class SsoHandler:
remote_user_id, remote_user_id,
) )
user_id_to_verify = await self._auth_handler.get_session_data( user_id_to_verify: str = await self._auth_handler.get_session_data(
ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
) # type: str )
if not user_id: if not user_id:
logger.warning( logger.warning(
@ -793,7 +793,7 @@ class SsoHandler:
session.use_display_name = use_display_name session.use_display_name = use_display_name
emails_from_idp = set(session.emails) emails_from_idp = set(session.emails)
filtered_emails = set() # type: Set[str] filtered_emails: Set[str] = set()
# we iterate through the list rather than just building a set conjunction, so # we iterate through the list rather than just building a set conjunction, so
# that we can log attempts to use unknown addresses # that we can log attempts to use unknown addresses

View file

@ -49,7 +49,7 @@ class StatsHandler:
self.stats_enabled = hs.config.stats_enabled self.stats_enabled = hs.config.stats_enabled
# The current position in the current_state_delta stream # The current position in the current_state_delta stream
self.pos = None # type: Optional[int] self.pos: Optional[int] = None
# Guard to ensure we only process deltas one at a time # Guard to ensure we only process deltas one at a time
self._is_processing = False self._is_processing = False
@ -131,10 +131,10 @@ class StatsHandler:
mapping from room/user ID to changes in the various fields. mapping from room/user ID to changes in the various fields.
""" """
room_to_stats_deltas = {} # type: Dict[str, CounterType[str]] room_to_stats_deltas: Dict[str, CounterType[str]] = {}
user_to_stats_deltas = {} # type: Dict[str, CounterType[str]] user_to_stats_deltas: Dict[str, CounterType[str]] = {}
room_to_state_updates = {} # type: Dict[str, Dict[str, Any]] room_to_state_updates: Dict[str, Dict[str, Any]] = {}
for delta in deltas: for delta in deltas:
typ = delta["type"] typ = delta["type"]
@ -164,7 +164,7 @@ class StatsHandler:
) )
continue continue
event_content = {} # type: JsonDict event_content: JsonDict = {}
if event_id is not None: if event_id is not None:
event = await self.store.get_event(event_id, allow_none=True) event = await self.store.get_event(event_id, allow_none=True)

View file

@ -278,12 +278,14 @@ class SyncHandler:
self.state_store = self.storage.state self.state_store = self.storage.state
# ExpiringCache((User, Device)) -> LruCache(user_id => event_id) # ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
self.lazy_loaded_members_cache = ExpiringCache( self.lazy_loaded_members_cache: ExpiringCache[
Tuple[str, Optional[str]], LruCache[str, str]
] = ExpiringCache(
"lazy_loaded_members_cache", "lazy_loaded_members_cache",
self.clock, self.clock,
max_len=0, max_len=0,
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]] )
async def wait_for_sync_for_user( async def wait_for_sync_for_user(
self, self,
@ -440,7 +442,7 @@ class SyncHandler:
) )
now_token = now_token.copy_and_replace("typing_key", typing_key) now_token = now_token.copy_and_replace("typing_key", typing_key)
ephemeral_by_room = {} # type: JsonDict ephemeral_by_room: JsonDict = {}
for event in typing: for event in typing:
# we want to exclude the room_id from the event, but modifying the # we want to exclude the room_id from the event, but modifying the
@ -502,7 +504,7 @@ class SyncHandler:
# We check if there are any state events, if there are then we pass # We check if there are any state events, if there are then we pass
# all current state events to the filter_events function. This is to # all current state events to the filter_events function. This is to
# ensure that we always include current state in the timeline # ensure that we always include current state in the timeline
current_state_ids = frozenset() # type: FrozenSet[str] current_state_ids: FrozenSet[str] = frozenset()
if any(e.is_state() for e in recents): if any(e.is_state() for e in recents):
current_state_ids_map = await self.store.get_current_state_ids( current_state_ids_map = await self.store.get_current_state_ids(
room_id room_id
@ -783,9 +785,9 @@ class SyncHandler:
def get_lazy_loaded_members_cache( def get_lazy_loaded_members_cache(
self, cache_key: Tuple[str, Optional[str]] self, cache_key: Tuple[str, Optional[str]]
) -> LruCache[str, str]: ) -> LruCache[str, str]:
cache = self.lazy_loaded_members_cache.get( cache: Optional[LruCache[str, str]] = self.lazy_loaded_members_cache.get(
cache_key cache_key
) # type: Optional[LruCache[str, str]] )
if cache is None: if cache is None:
logger.debug("creating LruCache for %r", cache_key) logger.debug("creating LruCache for %r", cache_key)
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE) cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
@ -984,7 +986,7 @@ class SyncHandler:
if t[0] == EventTypes.Member: if t[0] == EventTypes.Member:
cache.set(t[1], event_id) cache.set(t[1], event_id)
state = {} # type: Dict[str, EventBase] state: Dict[str, EventBase] = {}
if state_ids: if state_ids:
state = await self.store.get_events(list(state_ids.values())) state = await self.store.get_events(list(state_ids.values()))
@ -1088,8 +1090,8 @@ class SyncHandler:
logger.debug("Fetching OTK data") logger.debug("Fetching OTK data")
device_id = sync_config.device_id device_id = sync_config.device_id
one_time_key_counts = {} # type: JsonDict one_time_key_counts: JsonDict = {}
unused_fallback_key_types = [] # type: List[str] unused_fallback_key_types: List[str] = []
if device_id: if device_id:
one_time_key_counts = await self.store.count_e2e_one_time_keys( one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id user_id, device_id
@ -1437,7 +1439,7 @@ class SyncHandler:
) )
if block_all_room_ephemeral: if block_all_room_ephemeral:
ephemeral_by_room = {} # type: Dict[str, List[JsonDict]] ephemeral_by_room: Dict[str, List[JsonDict]] = {}
else: else:
now_token, ephemeral_by_room = await self.ephemeral_by_room( now_token, ephemeral_by_room = await self.ephemeral_by_room(
sync_result_builder, sync_result_builder,
@ -1468,7 +1470,7 @@ class SyncHandler:
# If there is ignored users account data and it matches the proper type, # If there is ignored users account data and it matches the proper type,
# then use it. # then use it.
ignored_users = frozenset() # type: FrozenSet[str] ignored_users: FrozenSet[str] = frozenset()
if ignored_account_data: if ignored_account_data:
ignored_users_data = ignored_account_data.get("ignored_users", {}) ignored_users_data = ignored_account_data.get("ignored_users", {})
if isinstance(ignored_users_data, dict): if isinstance(ignored_users_data, dict):
@ -1586,7 +1588,7 @@ class SyncHandler:
user_id, since_token.room_key, now_token.room_key user_id, since_token.room_key, now_token.room_key
) )
mem_change_events_by_room_id = {} # type: Dict[str, List[EventBase]] mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
for event in rooms_changed: for event in rooms_changed:
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
@ -1722,7 +1724,7 @@ class SyncHandler:
# This is all screaming out for a refactor, as the logic here is # This is all screaming out for a refactor, as the logic here is
# subtle and the moving parts numerous. # subtle and the moving parts numerous.
if leave_event.internal_metadata.is_out_of_band_membership(): if leave_event.internal_metadata.is_out_of_band_membership():
batch_events = [leave_event] # type: Optional[List[EventBase]] batch_events: Optional[List[EventBase]] = [leave_event]
else: else:
batch_events = None batch_events = None
@ -1971,7 +1973,7 @@ class SyncHandler:
room_id, batch, sync_config, since_token, now_token, full_state=full_state room_id, batch, sync_config, since_token, now_token, full_state=full_state
) )
summary = {} # type: Optional[JsonDict] summary: Optional[JsonDict] = {}
# we include a summary in room responses when we're lazy loading # we include a summary in room responses when we're lazy loading
# members (as the client otherwise doesn't have enough info to form # members (as the client otherwise doesn't have enough info to form
@ -1995,7 +1997,7 @@ class SyncHandler:
) )
if room_builder.rtype == "joined": if room_builder.rtype == "joined":
unread_notifications = {} # type: Dict[str, int] unread_notifications: Dict[str, int] = {}
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,

View file

@ -68,11 +68,11 @@ class FollowerTypingHandler:
) )
# map room IDs to serial numbers # map room IDs to serial numbers
self._room_serials = {} # type: Dict[str, int] self._room_serials: Dict[str, int] = {}
# map room IDs to sets of users currently typing # map room IDs to sets of users currently typing
self._room_typing = {} # type: Dict[str, Set[str]] self._room_typing: Dict[str, Set[str]] = {}
self._member_last_federation_poke = {} # type: Dict[RoomMember, int] self._member_last_federation_poke: Dict[RoomMember, int] = {}
self.wheel_timer = WheelTimer(bucket_size=5000) self.wheel_timer = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0 self._latest_room_serial = 0
@ -217,7 +217,7 @@ class TypingWriterHandler(FollowerTypingHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room) hs.get_distributor().observe("user_left_room", self.user_left_room)
# clock time we expect to stop # clock time we expect to stop
self._member_typing_until = {} # type: Dict[RoomMember, int] self._member_typing_until: Dict[RoomMember, int] = {}
# caches which room_ids changed at which serials # caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache( self._typing_stream_change_cache = StreamChangeCache(
@ -405,9 +405,9 @@ class TypingWriterHandler(FollowerTypingHandler):
if last_id == current_id: if last_id == current_id:
return [], current_id, False return [], current_id, False
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( changed_rooms: Optional[
last_id Iterable[str]
) # type: Optional[Iterable[str]] ] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
if changed_rooms is None: if changed_rooms is None:
changed_rooms = self._room_serials changed_rooms = self._room_serials

View file

@ -52,7 +52,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.search_all_users = hs.config.user_directory_search_all_users self.search_all_users = hs.config.user_directory_search_all_users
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
# The current position in the current_state_delta stream # The current position in the current_state_delta stream
self.pos = None # type: Optional[int] self.pos: Optional[int] = None
# Guard to ensure we only process deltas one at a time # Guard to ensure we only process deltas one at a time
self._is_processing = False self._is_processing = False

View file

@ -402,9 +402,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
# Get the room ID from the identifier. # Get the room ID from the identifier.
try: try:
remote_room_hosts = [ remote_room_hosts: Optional[List[str]] = [
x.decode("ascii") for x in request.args[b"server_name"] x.decode("ascii") for x in request.args[b"server_name"]
] # type: Optional[List[str]] ]
except Exception: except Exception:
remote_room_hosts = None remote_room_hosts = None
room_id, remote_room_hosts = await self.resolve_room_id( room_id, remote_room_hosts = await self.resolve_room_id(
@ -659,9 +659,7 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8") filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str: if filter_str:
filter_json = urlparse.unquote(filter_str) filter_json = urlparse.unquote(filter_str)
event_filter = Filter( event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
json_decoder.decode(filter_json)
) # type: Optional[Filter]
else: else:
event_filter = None event_filter = None

View file

@ -357,7 +357,7 @@ class UserRegisterServlet(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.reactor = hs.get_reactor() self.reactor = hs.get_reactor()
self.nonces = {} # type: Dict[str, int] self.nonces: Dict[str, int] = {}
self.hs = hs self.hs = hs
def _clear_old_nonces(self): def _clear_old_nonces(self):

View file

@ -121,7 +121,7 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE}) flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
sso_flow = { sso_flow: JsonDict = {
"type": LoginRestServlet.SSO_TYPE, "type": LoginRestServlet.SSO_TYPE,
"identity_providers": [ "identity_providers": [
_get_auth_flow_dict_for_idp( _get_auth_flow_dict_for_idp(
@ -129,7 +129,7 @@ class LoginRestServlet(RestServlet):
) )
for idp in self._sso_handler.get_identity_providers().values() for idp in self._sso_handler.get_identity_providers().values()
], ],
} # type: JsonDict }
if self._msc2858_enabled: if self._msc2858_enabled:
# backwards-compatibility support for clients which don't # backwards-compatibility support for clients which don't
@ -447,7 +447,7 @@ def _get_auth_flow_dict_for_idp(
use_unstable_brands: whether we should use brand identifiers suitable use_unstable_brands: whether we should use brand identifiers suitable
for the unstable API for the unstable API
""" """
e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name}
if idp.idp_icon: if idp.idp_icon:
e["icon"] = idp.idp_icon e["icon"] = idp.idp_icon
if idp.idp_brand: if idp.idp_brand:
@ -561,7 +561,7 @@ class SsoRedirectServlet(RestServlet):
finish_request(request) finish_request(request)
return return
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore args: Dict[bytes, List[bytes]] = request.args # type: ignore
client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True) client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True)
sso_url = await self._sso_handler.handle_redirect_request( sso_url = await self._sso_handler.handle_redirect_request(
request, request,

View file

@ -783,7 +783,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
server = parse_string(request, "server", default=None) server = parse_string(request, "server", default=None)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
limit = int(content.get("limit", 100)) # type: Optional[int] limit: Optional[int] = int(content.get("limit", 100))
since_token = content.get("since", None) since_token = content.get("since", None)
search_filter = content.get("filter", None) search_filter = content.get("filter", None)
@ -929,9 +929,7 @@ class RoomMessageListRestServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8") filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str: if filter_str:
filter_json = urlparse.unquote(filter_str) filter_json = urlparse.unquote(filter_str)
event_filter = Filter( event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
json_decoder.decode(filter_json)
) # type: Optional[Filter]
if ( if (
event_filter event_filter
and event_filter.filter_json.get("event_format", "client") and event_filter.filter_json.get("event_format", "client")
@ -1044,9 +1042,7 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8") filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str: if filter_str:
filter_json = urlparse.unquote(filter_str) filter_json = urlparse.unquote(filter_str)
event_filter = Filter( event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
json_decoder.decode(filter_json)
) # type: Optional[Filter]
else: else:
event_filter = None event_filter = None

View file

@ -59,7 +59,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
requester, message_type, content["messages"] requester, message_type, content["messages"]
) )
response = (200, {}) # type: Tuple[int, dict] response: Tuple[int, dict] = (200, {})
return response return response

View file

@ -117,7 +117,7 @@ class ConsentResource(DirectServeHtmlResource):
has_consented = False has_consented = False
public_version = username == "" public_version = username == ""
if not public_version: if not public_version:
args = request.args # type: Dict[bytes, List[bytes]] args: Dict[bytes, List[bytes]] = request.args
userhmac_bytes = parse_bytes_from_args(args, "h", required=True) userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac_bytes) self._check_hash(username, userhmac_bytes)
@ -154,7 +154,7 @@ class ConsentResource(DirectServeHtmlResource):
""" """
version = parse_string(request, "v", required=True) version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True) username = parse_string(request, "u", required=True)
args = request.args # type: Dict[bytes, List[bytes]] args: Dict[bytes, List[bytes]] = request.args
userhmac = parse_bytes_from_args(args, "h", required=True) userhmac = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac) self._check_hash(username, userhmac)

View file

@ -97,7 +97,7 @@ class RemoteKey(DirectServeJsonResource):
async def _async_render_GET(self, request): async def _async_render_GET(self, request):
if len(request.postpath) == 1: if len(request.postpath) == 1:
(server,) = request.postpath (server,) = request.postpath
query = {server.decode("ascii"): {}} # type: dict query: dict = {server.decode("ascii"): {}}
elif len(request.postpath) == 2: elif len(request.postpath) == 2:
server, key_id = request.postpath server, key_id = request.postpath
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
@ -141,7 +141,7 @@ class RemoteKey(DirectServeJsonResource):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
# Note that the value is unused. # Note that the value is unused.
cache_misses = {} # type: Dict[str, Dict[str, int]] cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), results in cached.items(): for (server_name, key_id, _), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results] results = [(result["ts_added_ms"], result) for result in results]

View file

@ -49,7 +49,7 @@ TEXT_CONTENT_TYPES = [
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try: try:
# The type on postpath seems incorrect in Twisted 21.2.0. # The type on postpath seems incorrect in Twisted 21.2.0.
postpath = request.postpath # type: List[bytes] # type: ignore postpath: List[bytes] = request.postpath # type: ignore
assert postpath assert postpath
# This allows users to append e.g. /test.png to the URL. Useful for # This allows users to append e.g. /test.png to the URL. Useful for

View file

@ -78,16 +78,16 @@ class MediaRepository:
Thumbnailer.set_limits(self.max_image_pixels) Thumbnailer.set_limits(self.max_image_pixels)
self.primary_base_path = hs.config.media_store_path # type: str self.primary_base_path: str = hs.config.media_store_path
self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path)
self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote") self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]] self.recently_accessed_remotes: Set[Tuple[str, str]] = set()
self.recently_accessed_locals = set() # type: Set[str] self.recently_accessed_locals: Set[str] = set()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@ -711,7 +711,7 @@ class MediaRepository:
# We deduplicate the thumbnail sizes by ignoring the cropped versions if # We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one. # they have the same dimensions of a scaled one.
thumbnails = {} # type: Dict[Tuple[int, int, str], str] thumbnails: Dict[Tuple[int, int, str], str] = {}
for r_width, r_height, r_method, r_type in requirements: for r_width, r_height, r_method, r_type in requirements:
if r_method == "crop": if r_method == "crop":
thumbnails.setdefault((r_width, r_height, r_type), r_method) thumbnails.setdefault((r_width, r_height, r_type), r_method)

View file

@ -191,7 +191,7 @@ class MediaStorage:
for provider in self.storage_providers: for provider in self.storage_providers:
for path in paths: for path in paths:
res = await provider.fetch(path, file_info) # type: Any res: Any = await provider.fetch(path, file_info)
if res: if res:
logger.debug("Streaming %s from %s", path, provider) logger.debug("Streaming %s from %s", path, provider)
return res return res
@ -233,7 +233,7 @@ class MediaStorage:
os.makedirs(dirname) os.makedirs(dirname)
for provider in self.storage_providers: for provider in self.storage_providers:
res = await provider.fetch(path, file_info) # type: Any res: Any = await provider.fetch(path, file_info)
if res: if res:
with res: with res:
consumer = BackgroundFileConsumer( consumer = BackgroundFileConsumer(

View file

@ -169,12 +169,12 @@ class PreviewUrlResource(DirectServeJsonResource):
# memory cache mapping urls to an ObservableDeferred returning # memory cache mapping urls to an ObservableDeferred returning
# JSON-encoded OG metadata # JSON-encoded OG metadata
self._cache = ExpiringCache( self._cache: ExpiringCache[str, ObservableDeferred] = ExpiringCache(
cache_name="url_previews", cache_name="url_previews",
clock=self.clock, clock=self.clock,
# don't spider URLs more often than once an hour # don't spider URLs more often than once an hour
expiry_ms=ONE_HOUR, expiry_ms=ONE_HOUR,
) # type: ExpiringCache[str, ObservableDeferred] )
if self._worker_run_media_background_jobs: if self._worker_run_media_background_jobs:
self._cleaner_loop = self.clock.looping_call( self._cleaner_loop = self.clock.looping_call(
@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
# If this URL can be accessed via oEmbed, use that instead. # If this URL can be accessed via oEmbed, use that instead.
url_to_download = url # type: Optional[str] url_to_download: Optional[str] = url
oembed_url = self._get_oembed_url(url) oembed_url = self._get_oembed_url(url)
if oembed_url: if oembed_url:
# The result might be a new URL to download, or it might be HTML content. # The result might be a new URL to download, or it might be HTML content.
@ -788,7 +788,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
# "og:video:height" : "720", # "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = {} # type: Dict[str, Optional[str]] og: Dict[str, Optional[str]] = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if "content" in tag.attrib: if "content" in tag.attrib:
# if we've got more than 50 tags, someone is taking the piss # if we've got more than 50 tags, someone is taking the piss

View file

@ -61,11 +61,11 @@ class UploadResource(DirectServeJsonResource):
errcode=Codes.TOO_LARGE, errcode=Codes.TOO_LARGE,
) )
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore args: Dict[bytes, List[bytes]] = request.args # type: ignore
upload_name_bytes = parse_bytes_from_args(args, "filename") upload_name_bytes = parse_bytes_from_args(args, "filename")
if upload_name_bytes: if upload_name_bytes:
try: try:
upload_name = upload_name_bytes.decode("utf8") # type: Optional[str] upload_name: Optional[str] = upload_name_bytes.decode("utf8")
except UnicodeDecodeError: except UnicodeDecodeError:
raise SynapseError( raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400 msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400
@ -89,7 +89,7 @@ class UploadResource(DirectServeJsonResource):
# TODO(markjh): parse content-dispostion # TODO(markjh): parse content-dispostion
try: try:
content = request.content # type: IO # type: ignore content: IO = request.content # type: ignore
content_uri = await self.media_repo.create_content( content_uri = await self.media_repo.create_content(
media_type, upload_name, content, content_length, requester.user media_type, upload_name, content, content_length, requester.user
) )

View file

@ -118,9 +118,9 @@ class AccountDetailsResource(DirectServeHtmlResource):
use_display_name = parse_boolean(request, "use_display_name", default=False) use_display_name = parse_boolean(request, "use_display_name", default=False)
try: try:
emails_to_use = [ emails_to_use: List[str] = [
val.decode("utf-8") for val in request.args.get(b"use_email", []) val.decode("utf-8") for val in request.args.get(b"use_email", [])
] # type: List[str] ]
except ValueError: except ValueError:
raise SynapseError(400, "Query parameter use_email must be utf-8") raise SynapseError(400, "Query parameter use_email must be utf-8")
except SynapseError as e: except SynapseError as e: