mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 05:13:50 +01:00
Use inline type hints in handlers/
and rest/
. (#10382)
This commit is contained in:
parent
36dc15412d
commit
98aec1cc9d
43 changed files with 212 additions and 215 deletions
1
changelog.d/10382.misc
Normal file
1
changelog.d/10382.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Convert internal type variable syntax to reflect wider ecosystem use.
|
|
@ -38,10 +38,10 @@ class BaseHandler:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastore() # type: synapse.storage.DataStore
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler
|
self.state_handler = hs.get_state_handler()
|
||||||
self.distributor = hs.get_distributor()
|
self.distributor = hs.get_distributor()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
@ -55,12 +55,12 @@ class BaseHandler:
|
||||||
# Check whether ratelimiting room admin message redaction is enabled
|
# Check whether ratelimiting room admin message redaction is enabled
|
||||||
# by the presence of rate limits in the config
|
# by the presence of rate limits in the config
|
||||||
if self.hs.config.rc_admin_redaction:
|
if self.hs.config.rc_admin_redaction:
|
||||||
self.admin_redaction_ratelimiter = Ratelimiter(
|
self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
|
||||||
store=self.store,
|
store=self.store,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=self.hs.config.rc_admin_redaction.per_second,
|
rate_hz=self.hs.config.rc_admin_redaction.per_second,
|
||||||
burst_count=self.hs.config.rc_admin_redaction.burst_count,
|
burst_count=self.hs.config.rc_admin_redaction.burst_count,
|
||||||
) # type: Optional[Ratelimiter]
|
)
|
||||||
else:
|
else:
|
||||||
self.admin_redaction_ratelimiter = None
|
self.admin_redaction_ratelimiter = None
|
||||||
|
|
||||||
|
|
|
@ -139,7 +139,7 @@ class AdminHandler(BaseHandler):
|
||||||
to_key = RoomStreamToken(None, stream_ordering)
|
to_key = RoomStreamToken(None, stream_ordering)
|
||||||
|
|
||||||
# Events that we've processed in this room
|
# Events that we've processed in this room
|
||||||
written_events = set() # type: Set[str]
|
written_events: Set[str] = set()
|
||||||
|
|
||||||
# We need to track gaps in the events stream so that we can then
|
# We need to track gaps in the events stream so that we can then
|
||||||
# write out the state at those events. We do this by keeping track
|
# write out the state at those events. We do this by keeping track
|
||||||
|
@ -152,7 +152,7 @@ class AdminHandler(BaseHandler):
|
||||||
# The reverse mapping to above, i.e. map from unseen event to events
|
# The reverse mapping to above, i.e. map from unseen event to events
|
||||||
# that have the unseen event in their prev_events, i.e. the unseen
|
# that have the unseen event in their prev_events, i.e. the unseen
|
||||||
# events "children".
|
# events "children".
|
||||||
unseen_to_child_events = {} # type: Dict[str, Set[str]]
|
unseen_to_child_events: Dict[str, Set[str]] = {}
|
||||||
|
|
||||||
# We fetch events in the room the user could see by fetching *all*
|
# We fetch events in the room the user could see by fetching *all*
|
||||||
# events that we have and then filtering, this isn't the most
|
# events that we have and then filtering, this isn't the most
|
||||||
|
|
|
@ -96,7 +96,7 @@ class ApplicationServicesHandler:
|
||||||
self.current_max, limit
|
self.current_max, limit
|
||||||
)
|
)
|
||||||
|
|
||||||
events_by_room = {} # type: Dict[str, List[EventBase]]
|
events_by_room: Dict[str, List[EventBase]] = {}
|
||||||
for event in events:
|
for event in events:
|
||||||
events_by_room.setdefault(event.room_id, []).append(event)
|
events_by_room.setdefault(event.room_id, []).append(event)
|
||||||
|
|
||||||
|
@ -275,7 +275,7 @@ class ApplicationServicesHandler:
|
||||||
async def _handle_presence(
|
async def _handle_presence(
|
||||||
self, service: ApplicationService, users: Collection[Union[str, UserID]]
|
self, service: ApplicationService, users: Collection[Union[str, UserID]]
|
||||||
) -> List[JsonDict]:
|
) -> List[JsonDict]:
|
||||||
events = [] # type: List[JsonDict]
|
events: List[JsonDict] = []
|
||||||
presence_source = self.event_sources.sources["presence"]
|
presence_source = self.event_sources.sources["presence"]
|
||||||
from_key = await self.store.get_type_stream_id_for_appservice(
|
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||||
service, "presence"
|
service, "presence"
|
||||||
|
@ -375,7 +375,7 @@ class ApplicationServicesHandler:
|
||||||
self, only_protocol: Optional[str] = None
|
self, only_protocol: Optional[str] = None
|
||||||
) -> Dict[str, JsonDict]:
|
) -> Dict[str, JsonDict]:
|
||||||
services = self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
protocols = {} # type: Dict[str, List[JsonDict]]
|
protocols: Dict[str, List[JsonDict]] = {}
|
||||||
|
|
||||||
# Collect up all the individual protocol responses out of the ASes
|
# Collect up all the individual protocol responses out of the ASes
|
||||||
for s in services:
|
for s in services:
|
||||||
|
|
|
@ -191,7 +191,7 @@ class AuthHandler(BaseHandler):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
|
self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
|
||||||
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
|
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
|
||||||
inst = auth_checker_class(hs)
|
inst = auth_checker_class(hs)
|
||||||
if inst.is_enabled():
|
if inst.is_enabled():
|
||||||
|
@ -296,7 +296,7 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
# A mapping of user ID to extra attributes to include in the login
|
# A mapping of user ID to extra attributes to include in the login
|
||||||
# response.
|
# response.
|
||||||
self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes]
|
self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {}
|
||||||
|
|
||||||
async def validate_user_via_ui_auth(
|
async def validate_user_via_ui_auth(
|
||||||
self,
|
self,
|
||||||
|
@ -500,7 +500,7 @@ class AuthHandler(BaseHandler):
|
||||||
all the stages in any of the permitted flows.
|
all the stages in any of the permitted flows.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sid = None # type: Optional[str]
|
sid: Optional[str] = None
|
||||||
authdict = clientdict.pop("auth", {})
|
authdict = clientdict.pop("auth", {})
|
||||||
if "session" in authdict:
|
if "session" in authdict:
|
||||||
sid = authdict["session"]
|
sid = authdict["session"]
|
||||||
|
@ -588,9 +588,9 @@ class AuthHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
# check auth type currently being presented
|
# check auth type currently being presented
|
||||||
errordict = {} # type: Dict[str, Any]
|
errordict: Dict[str, Any] = {}
|
||||||
if "type" in authdict:
|
if "type" in authdict:
|
||||||
login_type = authdict["type"] # type: str
|
login_type: str = authdict["type"]
|
||||||
try:
|
try:
|
||||||
result = await self._check_auth_dict(authdict, clientip)
|
result = await self._check_auth_dict(authdict, clientip)
|
||||||
if result:
|
if result:
|
||||||
|
@ -766,7 +766,7 @@ class AuthHandler(BaseHandler):
|
||||||
LoginType.TERMS: self._get_params_terms,
|
LoginType.TERMS: self._get_params_terms,
|
||||||
}
|
}
|
||||||
|
|
||||||
params = {} # type: Dict[str, Any]
|
params: Dict[str, Any] = {}
|
||||||
|
|
||||||
for f in public_flows:
|
for f in public_flows:
|
||||||
for stage in f:
|
for stage in f:
|
||||||
|
@ -1530,9 +1530,9 @@ class AuthHandler(BaseHandler):
|
||||||
except StoreError:
|
except StoreError:
|
||||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||||
|
|
||||||
user_id_to_verify = await self.get_session_data(
|
user_id_to_verify: str = await self.get_session_data(
|
||||||
session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
|
session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
|
||||||
) # type: str
|
)
|
||||||
|
|
||||||
idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
|
idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
|
||||||
user_id_to_verify
|
user_id_to_verify
|
||||||
|
|
|
@ -171,7 +171,7 @@ class CasHandler:
|
||||||
|
|
||||||
# Iterate through the nodes and pull out the user and any extra attributes.
|
# Iterate through the nodes and pull out the user and any extra attributes.
|
||||||
user = None
|
user = None
|
||||||
attributes = {} # type: Dict[str, List[Optional[str]]]
|
attributes: Dict[str, List[Optional[str]]] = {}
|
||||||
for child in root[0]:
|
for child in root[0]:
|
||||||
if child.tag.endswith("user"):
|
if child.tag.endswith("user"):
|
||||||
user = child.text
|
user = child.text
|
||||||
|
|
|
@ -452,7 +452,7 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||||
user_id
|
user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
hosts = set() # type: Set[str]
|
hosts: Set[str] = set()
|
||||||
if self.hs.is_mine_id(user_id):
|
if self.hs.is_mine_id(user_id):
|
||||||
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
|
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
|
||||||
hosts.discard(self.server_name)
|
hosts.discard(self.server_name)
|
||||||
|
@ -613,20 +613,20 @@ class DeviceListUpdater:
|
||||||
self._remote_edu_linearizer = Linearizer(name="remote_device_list")
|
self._remote_edu_linearizer = Linearizer(name="remote_device_list")
|
||||||
|
|
||||||
# user_id -> list of updates waiting to be handled.
|
# user_id -> list of updates waiting to be handled.
|
||||||
self._pending_updates = (
|
self._pending_updates: Dict[
|
||||||
{}
|
str, List[Tuple[str, str, Iterable[str], JsonDict]]
|
||||||
) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
|
] = {}
|
||||||
|
|
||||||
# Recently seen stream ids. We don't bother keeping these in the DB,
|
# Recently seen stream ids. We don't bother keeping these in the DB,
|
||||||
# but they're useful to have them about to reduce the number of spurious
|
# but they're useful to have them about to reduce the number of spurious
|
||||||
# resyncs.
|
# resyncs.
|
||||||
self._seen_updates = ExpiringCache(
|
self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
|
||||||
cache_name="device_update_edu",
|
cache_name="device_update_edu",
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
max_len=10000,
|
max_len=10000,
|
||||||
expiry_ms=30 * 60 * 1000,
|
expiry_ms=30 * 60 * 1000,
|
||||||
iterable=True,
|
iterable=True,
|
||||||
) # type: ExpiringCache[str, Set[str]]
|
)
|
||||||
|
|
||||||
# Attempt to resync out of sync device lists every 30s.
|
# Attempt to resync out of sync device lists every 30s.
|
||||||
self._resync_retry_in_progress = False
|
self._resync_retry_in_progress = False
|
||||||
|
@ -755,7 +755,7 @@ class DeviceListUpdater:
|
||||||
"""Given a list of updates for a user figure out if we need to do a full
|
"""Given a list of updates for a user figure out if we need to do a full
|
||||||
resync, or whether we have enough data that we can just apply the delta.
|
resync, or whether we have enough data that we can just apply the delta.
|
||||||
"""
|
"""
|
||||||
seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]
|
seen_updates: Set[str] = self._seen_updates.get(user_id, set())
|
||||||
|
|
||||||
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
|
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
|
||||||
|
|
||||||
|
|
|
@ -203,7 +203,7 @@ class DeviceMessageHandler:
|
||||||
log_kv({"number_of_to_device_messages": len(messages)})
|
log_kv({"number_of_to_device_messages": len(messages)})
|
||||||
set_tag("sender", sender_user_id)
|
set_tag("sender", sender_user_id)
|
||||||
local_messages = {}
|
local_messages = {}
|
||||||
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
|
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||||
for user_id, by_device in messages.items():
|
for user_id, by_device in messages.items():
|
||||||
# Ratelimit local cross-user key requests by the sending device.
|
# Ratelimit local cross-user key requests by the sending device.
|
||||||
if (
|
if (
|
||||||
|
|
|
@ -237,9 +237,9 @@ class DirectoryHandler(BaseHandler):
|
||||||
async def get_association(self, room_alias: RoomAlias) -> JsonDict:
|
async def get_association(self, room_alias: RoomAlias) -> JsonDict:
|
||||||
room_id = None
|
room_id = None
|
||||||
if self.hs.is_mine(room_alias):
|
if self.hs.is_mine(room_alias):
|
||||||
result = await self.get_association_from_room_alias(
|
result: Optional[
|
||||||
room_alias
|
RoomAliasMapping
|
||||||
) # type: Optional[RoomAliasMapping]
|
] = await self.get_association_from_room_alias(room_alias)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
room_id = result.room_id
|
room_id = result.room_id
|
||||||
|
|
|
@ -115,9 +115,9 @@ class E2eKeysHandler:
|
||||||
the number of in-flight queries at a time.
|
the number of in-flight queries at a time.
|
||||||
"""
|
"""
|
||||||
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
|
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
|
||||||
device_keys_query = query_body.get(
|
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
|
||||||
"device_keys", {}
|
"device_keys", {}
|
||||||
) # type: Dict[str, Iterable[str]]
|
)
|
||||||
|
|
||||||
# separate users by domain.
|
# separate users by domain.
|
||||||
# make a map from domain to user_id to device_ids
|
# make a map from domain to user_id to device_ids
|
||||||
|
@ -136,7 +136,7 @@ class E2eKeysHandler:
|
||||||
|
|
||||||
# First get local devices.
|
# First get local devices.
|
||||||
# A map of destination -> failure response.
|
# A map of destination -> failure response.
|
||||||
failures = {} # type: Dict[str, JsonDict]
|
failures: Dict[str, JsonDict] = {}
|
||||||
results = {}
|
results = {}
|
||||||
if local_query:
|
if local_query:
|
||||||
local_result = await self.query_local_devices(local_query)
|
local_result = await self.query_local_devices(local_query)
|
||||||
|
@ -151,11 +151,9 @@ class E2eKeysHandler:
|
||||||
|
|
||||||
# Now attempt to get any remote devices from our local cache.
|
# Now attempt to get any remote devices from our local cache.
|
||||||
# A map of destination -> user ID -> device IDs.
|
# A map of destination -> user ID -> device IDs.
|
||||||
remote_queries_not_in_cache = (
|
remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
|
||||||
{}
|
|
||||||
) # type: Dict[str, Dict[str, Iterable[str]]]
|
|
||||||
if remote_queries:
|
if remote_queries:
|
||||||
query_list = [] # type: List[Tuple[str, Optional[str]]]
|
query_list: List[Tuple[str, Optional[str]]] = []
|
||||||
for user_id, device_ids in remote_queries.items():
|
for user_id, device_ids in remote_queries.items():
|
||||||
if device_ids:
|
if device_ids:
|
||||||
query_list.extend(
|
query_list.extend(
|
||||||
|
@ -362,9 +360,9 @@ class E2eKeysHandler:
|
||||||
A map from user_id -> device_id -> device details
|
A map from user_id -> device_id -> device details
|
||||||
"""
|
"""
|
||||||
set_tag("local_query", query)
|
set_tag("local_query", query)
|
||||||
local_query = [] # type: List[Tuple[str, Optional[str]]]
|
local_query: List[Tuple[str, Optional[str]]] = []
|
||||||
|
|
||||||
result_dict = {} # type: Dict[str, Dict[str, dict]]
|
result_dict: Dict[str, Dict[str, dict]] = {}
|
||||||
for user_id, device_ids in query.items():
|
for user_id, device_ids in query.items():
|
||||||
# we use UserID.from_string to catch invalid user ids
|
# we use UserID.from_string to catch invalid user ids
|
||||||
if not self.is_mine(UserID.from_string(user_id)):
|
if not self.is_mine(UserID.from_string(user_id)):
|
||||||
|
@ -402,9 +400,9 @@ class E2eKeysHandler:
|
||||||
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
|
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""Handle a device key query from a federated server"""
|
"""Handle a device key query from a federated server"""
|
||||||
device_keys_query = query_body.get(
|
device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
|
||||||
"device_keys", {}
|
"device_keys", {}
|
||||||
) # type: Dict[str, Optional[List[str]]]
|
)
|
||||||
res = await self.query_local_devices(device_keys_query)
|
res = await self.query_local_devices(device_keys_query)
|
||||||
ret = {"device_keys": res}
|
ret = {"device_keys": res}
|
||||||
|
|
||||||
|
@ -421,8 +419,8 @@ class E2eKeysHandler:
|
||||||
async def claim_one_time_keys(
|
async def claim_one_time_keys(
|
||||||
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
|
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
local_query = [] # type: List[Tuple[str, str, str]]
|
local_query: List[Tuple[str, str, str]] = []
|
||||||
remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]]
|
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
|
||||||
|
|
||||||
for user_id, one_time_keys in query.get("one_time_keys", {}).items():
|
for user_id, one_time_keys in query.get("one_time_keys", {}).items():
|
||||||
# we use UserID.from_string to catch invalid user ids
|
# we use UserID.from_string to catch invalid user ids
|
||||||
|
@ -439,8 +437,8 @@ class E2eKeysHandler:
|
||||||
results = await self.store.claim_e2e_one_time_keys(local_query)
|
results = await self.store.claim_e2e_one_time_keys(local_query)
|
||||||
|
|
||||||
# A map of user ID -> device ID -> key ID -> key.
|
# A map of user ID -> device ID -> key ID -> key.
|
||||||
json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
|
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||||
failures = {} # type: Dict[str, JsonDict]
|
failures: Dict[str, JsonDict] = {}
|
||||||
for user_id, device_keys in results.items():
|
for user_id, device_keys in results.items():
|
||||||
for device_id, keys in device_keys.items():
|
for device_id, keys in device_keys.items():
|
||||||
for key_id, json_str in keys.items():
|
for key_id, json_str in keys.items():
|
||||||
|
@ -768,8 +766,8 @@ class E2eKeysHandler:
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError: if the input is malformed
|
SynapseError: if the input is malformed
|
||||||
"""
|
"""
|
||||||
signature_list = [] # type: List[SignatureListItem]
|
signature_list: List["SignatureListItem"] = []
|
||||||
failures = {} # type: Dict[str, Dict[str, JsonDict]]
|
failures: Dict[str, Dict[str, JsonDict]] = {}
|
||||||
if not signatures:
|
if not signatures:
|
||||||
return signature_list, failures
|
return signature_list, failures
|
||||||
|
|
||||||
|
@ -930,8 +928,8 @@ class E2eKeysHandler:
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError: if the input is malformed
|
SynapseError: if the input is malformed
|
||||||
"""
|
"""
|
||||||
signature_list = [] # type: List[SignatureListItem]
|
signature_list: List["SignatureListItem"] = []
|
||||||
failures = {} # type: Dict[str, Dict[str, JsonDict]]
|
failures: Dict[str, Dict[str, JsonDict]] = {}
|
||||||
if not signatures:
|
if not signatures:
|
||||||
return signature_list, failures
|
return signature_list, failures
|
||||||
|
|
||||||
|
@ -1300,7 +1298,7 @@ class SigningKeyEduUpdater:
|
||||||
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
|
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
|
||||||
|
|
||||||
# user_id -> list of updates waiting to be handled.
|
# user_id -> list of updates waiting to be handled.
|
||||||
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
|
self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {}
|
||||||
|
|
||||||
async def incoming_signing_key_update(
|
async def incoming_signing_key_update(
|
||||||
self, origin: str, edu_content: JsonDict
|
self, origin: str, edu_content: JsonDict
|
||||||
|
@ -1349,7 +1347,7 @@ class SigningKeyEduUpdater:
|
||||||
# This can happen since we batch updates
|
# This can happen since we batch updates
|
||||||
return
|
return
|
||||||
|
|
||||||
device_ids = [] # type: List[str]
|
device_ids: List[str] = []
|
||||||
|
|
||||||
logger.info("pending updates: %r", pending_updates)
|
logger.info("pending updates: %r", pending_updates)
|
||||||
|
|
||||||
|
|
|
@ -93,7 +93,7 @@ class EventStreamHandler(BaseHandler):
|
||||||
|
|
||||||
# When the user joins a new room, or another user joins a currently
|
# When the user joins a new room, or another user joins a currently
|
||||||
# joined room, we need to send down presence for those users.
|
# joined room, we need to send down presence for those users.
|
||||||
to_add = [] # type: List[JsonDict]
|
to_add: List[JsonDict] = []
|
||||||
for event in events:
|
for event in events:
|
||||||
if not isinstance(event, EventBase):
|
if not isinstance(event, EventBase):
|
||||||
continue
|
continue
|
||||||
|
@ -103,9 +103,9 @@ class EventStreamHandler(BaseHandler):
|
||||||
# Send down presence.
|
# Send down presence.
|
||||||
if event.state_key == auth_user_id:
|
if event.state_key == auth_user_id:
|
||||||
# Send down presence for everyone in the room.
|
# Send down presence for everyone in the room.
|
||||||
users = await self.store.get_users_in_room(
|
users: Iterable[str] = await self.store.get_users_in_room(
|
||||||
event.room_id
|
event.room_id
|
||||||
) # type: Iterable[str]
|
)
|
||||||
else:
|
else:
|
||||||
users = [event.state_key]
|
users = [event.state_key]
|
||||||
|
|
||||||
|
|
|
@ -181,7 +181,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
# When joining a room we need to queue any events for that room up.
|
# When joining a room we need to queue any events for that room up.
|
||||||
# For each room, a list of (pdu, origin) tuples.
|
# For each room, a list of (pdu, origin) tuples.
|
||||||
self.room_queues = {} # type: Dict[str, List[Tuple[EventBase, str]]]
|
self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
|
||||||
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
|
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
|
||||||
|
|
||||||
self._room_backfill = Linearizer("room_backfill")
|
self._room_backfill = Linearizer("room_backfill")
|
||||||
|
@ -368,7 +368,7 @@ class FederationHandler(BaseHandler):
|
||||||
ours = await self.state_store.get_state_groups_ids(room_id, seen)
|
ours = await self.state_store.get_state_groups_ids(room_id, seen)
|
||||||
|
|
||||||
# state_maps is a list of mappings from (type, state_key) to event_id
|
# state_maps is a list of mappings from (type, state_key) to event_id
|
||||||
state_maps = list(ours.values()) # type: List[StateMap[str]]
|
state_maps: List[StateMap[str]] = list(ours.values())
|
||||||
|
|
||||||
# we don't need this any more, let's delete it.
|
# we don't need this any more, let's delete it.
|
||||||
del ours
|
del ours
|
||||||
|
@ -845,7 +845,7 @@ class FederationHandler(BaseHandler):
|
||||||
# exact key to expect. Otherwise check it matches any key we
|
# exact key to expect. Otherwise check it matches any key we
|
||||||
# have for that device.
|
# have for that device.
|
||||||
|
|
||||||
current_keys = [] # type: Container[str]
|
current_keys: Container[str] = []
|
||||||
|
|
||||||
if device:
|
if device:
|
||||||
keys = device.get("keys", {}).get("keys", {})
|
keys = device.get("keys", {}).get("keys", {})
|
||||||
|
@ -1185,7 +1185,7 @@ class FederationHandler(BaseHandler):
|
||||||
if e_type == EventTypes.Member and event.membership == Membership.JOIN
|
if e_type == EventTypes.Member and event.membership == Membership.JOIN
|
||||||
]
|
]
|
||||||
|
|
||||||
joined_domains = {} # type: Dict[str, int]
|
joined_domains: Dict[str, int] = {}
|
||||||
for u, d in joined_users:
|
for u, d in joined_users:
|
||||||
try:
|
try:
|
||||||
dom = get_domain_from_id(u)
|
dom = get_domain_from_id(u)
|
||||||
|
@ -1314,7 +1314,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
room_version = await self.store.get_room_version(room_id)
|
room_version = await self.store.get_room_version(room_id)
|
||||||
|
|
||||||
event_map = {} # type: Dict[str, EventBase]
|
event_map: Dict[str, EventBase] = {}
|
||||||
|
|
||||||
async def get_event(event_id: str):
|
async def get_event(event_id: str):
|
||||||
with nested_logging_context(event_id):
|
with nested_logging_context(event_id):
|
||||||
|
@ -1596,7 +1596,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
# Ask the remote server to create a valid knock event for us. Once received,
|
# Ask the remote server to create a valid knock event for us. Once received,
|
||||||
# we sign the event
|
# we sign the event
|
||||||
params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]]
|
params: Dict[str, Iterable[str]] = {"ver": supported_room_versions}
|
||||||
origin, event, event_format_version = await self._make_and_verify_event(
|
origin, event, event_format_version = await self._make_and_verify_event(
|
||||||
target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
|
target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
|
||||||
)
|
)
|
||||||
|
@ -2453,14 +2453,14 @@ class FederationHandler(BaseHandler):
|
||||||
state_sets_d = await self.state_store.get_state_groups(
|
state_sets_d = await self.state_store.get_state_groups(
|
||||||
event.room_id, extrem_ids
|
event.room_id, extrem_ids
|
||||||
)
|
)
|
||||||
state_sets = list(state_sets_d.values()) # type: List[Iterable[EventBase]]
|
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
|
||||||
state_sets.append(state)
|
state_sets.append(state)
|
||||||
current_states = await self.state_handler.resolve_events(
|
current_states = await self.state_handler.resolve_events(
|
||||||
room_version, state_sets, event
|
room_version, state_sets, event
|
||||||
)
|
)
|
||||||
current_state_ids = {
|
current_state_ids: StateMap[str] = {
|
||||||
k: e.event_id for k, e in current_states.items()
|
k: e.event_id for k, e in current_states.items()
|
||||||
} # type: StateMap[str]
|
}
|
||||||
else:
|
else:
|
||||||
current_state_ids = await self.state_handler.get_current_state_ids(
|
current_state_ids = await self.state_handler.get_current_state_ids(
|
||||||
event.room_id, latest_event_ids=extrem_ids
|
event.room_id, latest_event_ids=extrem_ids
|
||||||
|
@ -2817,7 +2817,7 @@ class FederationHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
# exclude the state key of the new event from the current_state in the context.
|
# exclude the state key of the new event from the current_state in the context.
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]]
|
event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
|
||||||
else:
|
else:
|
||||||
event_key = None
|
event_key = None
|
||||||
state_updates = {
|
state_updates = {
|
||||||
|
@ -3156,7 +3156,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
logger.debug("Checking auth on event %r", event.content)
|
logger.debug("Checking auth on event %r", event.content)
|
||||||
|
|
||||||
last_exception = None # type: Optional[Exception]
|
last_exception: Optional[Exception] = None
|
||||||
|
|
||||||
# for each public key in the 3pid invite event
|
# for each public key in the 3pid invite event
|
||||||
for public_key_object in event_auth.get_public_keys(invite_event):
|
for public_key_object in event_auth.get_public_keys(invite_event):
|
||||||
|
|
|
@ -214,7 +214,7 @@ class GroupsLocalWorkerHandler:
|
||||||
async def bulk_get_publicised_groups(
|
async def bulk_get_publicised_groups(
|
||||||
self, user_ids: Iterable[str], proxy: bool = True
|
self, user_ids: Iterable[str], proxy: bool = True
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
destinations = {} # type: Dict[str, Set[str]]
|
destinations: Dict[str, Set[str]] = {}
|
||||||
local_users = set()
|
local_users = set()
|
||||||
|
|
||||||
for user_id in user_ids:
|
for user_id in user_ids:
|
||||||
|
@ -227,7 +227,7 @@ class GroupsLocalWorkerHandler:
|
||||||
raise SynapseError(400, "Some user_ids are not local")
|
raise SynapseError(400, "Some user_ids are not local")
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
failed_results = [] # type: List[str]
|
failed_results: List[str] = []
|
||||||
for destination, dest_user_ids in destinations.items():
|
for destination, dest_user_ids in destinations.items():
|
||||||
try:
|
try:
|
||||||
r = await self.transport_client.bulk_get_publicised_groups(
|
r = await self.transport_client.bulk_get_publicised_groups(
|
||||||
|
|
|
@ -46,9 +46,17 @@ class InitialSyncHandler(BaseHandler):
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.validator = EventValidator()
|
self.validator = EventValidator()
|
||||||
self.snapshot_cache = ResponseCache(
|
self.snapshot_cache: ResponseCache[
|
||||||
hs.get_clock(), "initial_sync_cache"
|
Tuple[
|
||||||
) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
|
str,
|
||||||
|
Optional[StreamToken],
|
||||||
|
Optional[StreamToken],
|
||||||
|
str,
|
||||||
|
Optional[int],
|
||||||
|
bool,
|
||||||
|
bool,
|
||||||
|
]
|
||||||
|
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
self.storage = hs.get_storage()
|
self.storage = hs.get_storage()
|
||||||
self.state_store = self.storage.state
|
self.state_store = self.storage.state
|
||||||
|
|
|
@ -81,7 +81,7 @@ class MessageHandler:
|
||||||
|
|
||||||
# The scheduled call to self._expire_event. None if no call is currently
|
# The scheduled call to self._expire_event. None if no call is currently
|
||||||
# scheduled.
|
# scheduled.
|
||||||
self._scheduled_expiry = None # type: Optional[IDelayedCall]
|
self._scheduled_expiry: Optional[IDelayedCall] = None
|
||||||
|
|
||||||
if not hs.config.worker_app:
|
if not hs.config.worker_app:
|
||||||
run_as_background_process(
|
run_as_background_process(
|
||||||
|
@ -196,9 +196,7 @@ class MessageHandler:
|
||||||
room_state_events = await self.state_store.get_state_for_events(
|
room_state_events = await self.state_store.get_state_for_events(
|
||||||
[event.event_id], state_filter=state_filter
|
[event.event_id], state_filter=state_filter
|
||||||
)
|
)
|
||||||
room_state = room_state_events[
|
room_state: Mapping[Any, EventBase] = room_state_events[event.event_id]
|
||||||
event.event_id
|
|
||||||
] # type: Mapping[Any, EventBase]
|
|
||||||
else:
|
else:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403,
|
403,
|
||||||
|
@ -421,9 +419,9 @@ class EventCreationHandler:
|
||||||
self.action_generator = hs.get_action_generator()
|
self.action_generator = hs.get_action_generator()
|
||||||
|
|
||||||
self.spam_checker = hs.get_spam_checker()
|
self.spam_checker = hs.get_spam_checker()
|
||||||
self.third_party_event_rules = (
|
self.third_party_event_rules: "ThirdPartyEventRules" = (
|
||||||
self.hs.get_third_party_event_rules()
|
self.hs.get_third_party_event_rules()
|
||||||
) # type: ThirdPartyEventRules
|
)
|
||||||
|
|
||||||
self._block_events_without_consent_error = (
|
self._block_events_without_consent_error = (
|
||||||
self.config.block_events_without_consent_error
|
self.config.block_events_without_consent_error
|
||||||
|
@ -440,7 +438,7 @@ class EventCreationHandler:
|
||||||
#
|
#
|
||||||
# map from room id to time-of-last-attempt.
|
# map from room id to time-of-last-attempt.
|
||||||
#
|
#
|
||||||
self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int]
|
self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {}
|
||||||
# The number of forward extremeities before a dummy event is sent.
|
# The number of forward extremeities before a dummy event is sent.
|
||||||
self._dummy_events_threshold = hs.config.dummy_events_threshold
|
self._dummy_events_threshold = hs.config.dummy_events_threshold
|
||||||
|
|
||||||
|
@ -465,9 +463,7 @@ class EventCreationHandler:
|
||||||
# Stores the state groups we've recently added to the joined hosts
|
# Stores the state groups we've recently added to the joined hosts
|
||||||
# external cache. Note that the timeout must be significantly less than
|
# external cache. Note that the timeout must be significantly less than
|
||||||
# the TTL on the external cache.
|
# the TTL on the external cache.
|
||||||
self._external_cache_joined_hosts_updates = (
|
self._external_cache_joined_hosts_updates: Optional[ExpiringCache] = None
|
||||||
None
|
|
||||||
) # type: Optional[ExpiringCache]
|
|
||||||
if self._external_cache.is_enabled():
|
if self._external_cache.is_enabled():
|
||||||
self._external_cache_joined_hosts_updates = ExpiringCache(
|
self._external_cache_joined_hosts_updates = ExpiringCache(
|
||||||
"_external_cache_joined_hosts_updates",
|
"_external_cache_joined_hosts_updates",
|
||||||
|
@ -1299,7 +1295,7 @@ class EventCreationHandler:
|
||||||
# Validate a newly added alias or newly added alt_aliases.
|
# Validate a newly added alias or newly added alt_aliases.
|
||||||
|
|
||||||
original_alias = None
|
original_alias = None
|
||||||
original_alt_aliases = [] # type: List[str]
|
original_alt_aliases: List[str] = []
|
||||||
|
|
||||||
original_event_id = event.unsigned.get("replaces_state")
|
original_event_id = event.unsigned.get("replaces_state")
|
||||||
if original_event_id:
|
if original_event_id:
|
||||||
|
|
|
@ -105,9 +105,9 @@ class OidcHandler:
|
||||||
assert provider_confs
|
assert provider_confs
|
||||||
|
|
||||||
self._token_generator = OidcSessionTokenGenerator(hs)
|
self._token_generator = OidcSessionTokenGenerator(hs)
|
||||||
self._providers = {
|
self._providers: Dict[str, "OidcProvider"] = {
|
||||||
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
|
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
|
||||||
} # type: Dict[str, OidcProvider]
|
}
|
||||||
|
|
||||||
async def load_metadata(self) -> None:
|
async def load_metadata(self) -> None:
|
||||||
"""Validate the config and load the metadata from the remote endpoint.
|
"""Validate the config and load the metadata from the remote endpoint.
|
||||||
|
@ -178,7 +178,7 @@ class OidcHandler:
|
||||||
# are two.
|
# are two.
|
||||||
|
|
||||||
for cookie_name, _ in _SESSION_COOKIES:
|
for cookie_name, _ in _SESSION_COOKIES:
|
||||||
session = request.getCookie(cookie_name) # type: Optional[bytes]
|
session: Optional[bytes] = request.getCookie(cookie_name)
|
||||||
if session is not None:
|
if session is not None:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
@ -277,7 +277,7 @@ class OidcProvider:
|
||||||
self._token_generator = token_generator
|
self._token_generator = token_generator
|
||||||
|
|
||||||
self._config = provider
|
self._config = provider
|
||||||
self._callback_url = hs.config.oidc_callback_url # type: str
|
self._callback_url: str = hs.config.oidc_callback_url
|
||||||
|
|
||||||
# Calculate the prefix for OIDC callback paths based on the public_baseurl.
|
# Calculate the prefix for OIDC callback paths based on the public_baseurl.
|
||||||
# We'll insert this into the Path= parameter of any session cookies we set.
|
# We'll insert this into the Path= parameter of any session cookies we set.
|
||||||
|
@ -290,7 +290,7 @@ class OidcProvider:
|
||||||
self._scopes = provider.scopes
|
self._scopes = provider.scopes
|
||||||
self._user_profile_method = provider.user_profile_method
|
self._user_profile_method = provider.user_profile_method
|
||||||
|
|
||||||
client_secret = None # type: Union[None, str, JwtClientSecret]
|
client_secret: Optional[Union[str, JwtClientSecret]] = None
|
||||||
if provider.client_secret:
|
if provider.client_secret:
|
||||||
client_secret = provider.client_secret
|
client_secret = provider.client_secret
|
||||||
elif provider.client_secret_jwt_key:
|
elif provider.client_secret_jwt_key:
|
||||||
|
@ -305,7 +305,7 @@ class OidcProvider:
|
||||||
provider.client_id,
|
provider.client_id,
|
||||||
client_secret,
|
client_secret,
|
||||||
provider.client_auth_method,
|
provider.client_auth_method,
|
||||||
) # type: ClientAuth
|
)
|
||||||
self._client_auth_method = provider.client_auth_method
|
self._client_auth_method = provider.client_auth_method
|
||||||
|
|
||||||
# cache of metadata for the identity provider (endpoint uris, mostly). This is
|
# cache of metadata for the identity provider (endpoint uris, mostly). This is
|
||||||
|
@ -324,7 +324,7 @@ class OidcProvider:
|
||||||
self._allow_existing_users = provider.allow_existing_users
|
self._allow_existing_users = provider.allow_existing_users
|
||||||
|
|
||||||
self._http_client = hs.get_proxied_http_client()
|
self._http_client = hs.get_proxied_http_client()
|
||||||
self._server_name = hs.config.server_name # type: str
|
self._server_name: str = hs.config.server_name
|
||||||
|
|
||||||
# identifier for the external_ids table
|
# identifier for the external_ids table
|
||||||
self.idp_id = provider.idp_id
|
self.idp_id = provider.idp_id
|
||||||
|
@ -1381,7 +1381,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
||||||
if display_name == "":
|
if display_name == "":
|
||||||
display_name = None
|
display_name = None
|
||||||
|
|
||||||
emails = [] # type: List[str]
|
emails: List[str] = []
|
||||||
email = render_template_field(self._config.email_template)
|
email = render_template_field(self._config.email_template)
|
||||||
if email:
|
if email:
|
||||||
emails.append(email)
|
emails.append(email)
|
||||||
|
@ -1391,7 +1391,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
|
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
|
||||||
extras = {} # type: Dict[str, str]
|
extras: Dict[str, str] = {}
|
||||||
for key, template in self._config.extra_attributes.items():
|
for key, template in self._config.extra_attributes.items():
|
||||||
try:
|
try:
|
||||||
extras[key] = template.render(user=userinfo).strip()
|
extras[key] = template.render(user=userinfo).strip()
|
||||||
|
|
|
@ -81,9 +81,9 @@ class PaginationHandler:
|
||||||
self._server_name = hs.hostname
|
self._server_name = hs.hostname
|
||||||
|
|
||||||
self.pagination_lock = ReadWriteLock()
|
self.pagination_lock = ReadWriteLock()
|
||||||
self._purges_in_progress_by_room = set() # type: Set[str]
|
self._purges_in_progress_by_room: Set[str] = set()
|
||||||
# map from purge id to PurgeStatus
|
# map from purge id to PurgeStatus
|
||||||
self._purges_by_id = {} # type: Dict[str, PurgeStatus]
|
self._purges_by_id: Dict[str, PurgeStatus] = {}
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
|
|
||||||
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
|
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
|
||||||
|
|
|
@ -378,14 +378,14 @@ class WorkerPresenceHandler(BasePresenceHandler):
|
||||||
|
|
||||||
# The number of ongoing syncs on this process, by user id.
|
# The number of ongoing syncs on this process, by user id.
|
||||||
# Empty if _presence_enabled is false.
|
# Empty if _presence_enabled is false.
|
||||||
self._user_to_num_current_syncs = {} # type: Dict[str, int]
|
self._user_to_num_current_syncs: Dict[str, int] = {}
|
||||||
|
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.instance_id = hs.get_instance_id()
|
self.instance_id = hs.get_instance_id()
|
||||||
|
|
||||||
# user_id -> last_sync_ms. Lists the users that have stopped syncing but
|
# user_id -> last_sync_ms. Lists the users that have stopped syncing but
|
||||||
# we haven't notified the presence writer of that yet
|
# we haven't notified the presence writer of that yet
|
||||||
self.users_going_offline = {} # type: Dict[str, int]
|
self.users_going_offline: Dict[str, int] = {}
|
||||||
|
|
||||||
self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
|
self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
|
||||||
self._set_state_client = ReplicationPresenceSetState.make_client(hs)
|
self._set_state_client = ReplicationPresenceSetState.make_client(hs)
|
||||||
|
@ -650,7 +650,7 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
|
|
||||||
# Set of users who have presence in the `user_to_current_state` that
|
# Set of users who have presence in the `user_to_current_state` that
|
||||||
# have not yet been persisted
|
# have not yet been persisted
|
||||||
self.unpersisted_users_changes = set() # type: Set[str]
|
self.unpersisted_users_changes: Set[str] = set()
|
||||||
|
|
||||||
hs.get_reactor().addSystemEventTrigger(
|
hs.get_reactor().addSystemEventTrigger(
|
||||||
"before",
|
"before",
|
||||||
|
@ -664,7 +664,7 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
|
|
||||||
# Keeps track of the number of *ongoing* syncs on this process. While
|
# Keeps track of the number of *ongoing* syncs on this process. While
|
||||||
# this is non zero a user will never go offline.
|
# this is non zero a user will never go offline.
|
||||||
self.user_to_num_current_syncs = {} # type: Dict[str, int]
|
self.user_to_num_current_syncs: Dict[str, int] = {}
|
||||||
|
|
||||||
# Keeps track of the number of *ongoing* syncs on other processes.
|
# Keeps track of the number of *ongoing* syncs on other processes.
|
||||||
# While any sync is ongoing on another process the user will never
|
# While any sync is ongoing on another process the user will never
|
||||||
|
@ -674,8 +674,8 @@ class PresenceHandler(BasePresenceHandler):
|
||||||
# we assume that all the sync requests on that process have stopped.
|
# we assume that all the sync requests on that process have stopped.
|
||||||
# Stored as a dict from process_id to set of user_id, and a dict of
|
# Stored as a dict from process_id to set of user_id, and a dict of
|
||||||
# process_id to millisecond timestamp last updated.
|
# process_id to millisecond timestamp last updated.
|
||||||
self.external_process_to_current_syncs = {} # type: Dict[str, Set[str]]
|
self.external_process_to_current_syncs: Dict[str, Set[str]] = {}
|
||||||
self.external_process_last_updated_ms = {} # type: Dict[str, int]
|
self.external_process_last_updated_ms: Dict[str, int] = {}
|
||||||
|
|
||||||
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
|
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
|
||||||
|
|
||||||
|
@ -1581,9 +1581,7 @@ class PresenceEventSource:
|
||||||
|
|
||||||
# The set of users that we're interested in and that have had a presence update.
|
# The set of users that we're interested in and that have had a presence update.
|
||||||
# We'll actually pull the presence updates for these users at the end.
|
# We'll actually pull the presence updates for these users at the end.
|
||||||
interested_and_updated_users = (
|
interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
|
||||||
set()
|
|
||||||
) # type: Union[Set[str], FrozenSet[str]]
|
|
||||||
|
|
||||||
if from_key:
|
if from_key:
|
||||||
# First get all users that have had a presence update
|
# First get all users that have had a presence update
|
||||||
|
@ -1950,8 +1948,8 @@ async def get_interested_parties(
|
||||||
A 2-tuple of `(room_ids_to_states, users_to_states)`,
|
A 2-tuple of `(room_ids_to_states, users_to_states)`,
|
||||||
with each item being a dict of `entity_name` -> `[UserPresenceState]`
|
with each item being a dict of `entity_name` -> `[UserPresenceState]`
|
||||||
"""
|
"""
|
||||||
room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]]
|
room_ids_to_states: Dict[str, List[UserPresenceState]] = {}
|
||||||
users_to_states = {} # type: Dict[str, List[UserPresenceState]]
|
users_to_states: Dict[str, List[UserPresenceState]] = {}
|
||||||
for state in states:
|
for state in states:
|
||||||
room_ids = await store.get_rooms_for_user(state.user_id)
|
room_ids = await store.get_rooms_for_user(state.user_id)
|
||||||
for room_id in room_ids:
|
for room_id in room_ids:
|
||||||
|
@ -2063,12 +2061,12 @@ class PresenceFederationQueue:
|
||||||
# stream_id, destinations, user_ids)`. We don't store the full states
|
# stream_id, destinations, user_ids)`. We don't store the full states
|
||||||
# for efficiency, and remote workers will already have the full states
|
# for efficiency, and remote workers will already have the full states
|
||||||
# cached.
|
# cached.
|
||||||
self._queue = [] # type: List[Tuple[int, int, Collection[str], Set[str]]]
|
self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = []
|
||||||
|
|
||||||
self._next_id = 1
|
self._next_id = 1
|
||||||
|
|
||||||
# Map from instance name to current token
|
# Map from instance name to current token
|
||||||
self._current_tokens = {} # type: Dict[str, int]
|
self._current_tokens: Dict[str, int] = {}
|
||||||
|
|
||||||
if self._queue_presence_updates:
|
if self._queue_presence_updates:
|
||||||
self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
|
self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
|
||||||
|
@ -2168,7 +2166,7 @@ class PresenceFederationQueue:
|
||||||
# handle the case where `from_token` stream ID has already been dropped.
|
# handle the case where `from_token` stream ID has already been dropped.
|
||||||
start_idx = max(from_token + 1 - self._next_id, -len(self._queue))
|
start_idx = max(from_token + 1 - self._next_id, -len(self._queue))
|
||||||
|
|
||||||
to_send = [] # type: List[Tuple[int, Tuple[str, str]]]
|
to_send: List[Tuple[int, Tuple[str, str]]] = []
|
||||||
limited = False
|
limited = False
|
||||||
new_id = upto_token
|
new_id = upto_token
|
||||||
for _, stream_id, destinations, user_ids in self._queue[start_idx:]:
|
for _, stream_id, destinations, user_ids in self._queue[start_idx:]:
|
||||||
|
@ -2216,7 +2214,7 @@ class PresenceFederationQueue:
|
||||||
if not self._federation:
|
if not self._federation:
|
||||||
return
|
return
|
||||||
|
|
||||||
hosts_to_users = {} # type: Dict[str, Set[str]]
|
hosts_to_users: Dict[str, Set[str]] = {}
|
||||||
for row in rows:
|
for row in rows:
|
||||||
hosts_to_users.setdefault(row.destination, set()).add(row.user_id)
|
hosts_to_users.setdefault(row.destination, set()).add(row.user_id)
|
||||||
|
|
||||||
|
|
|
@ -197,7 +197,7 @@ class ProfileHandler(BaseHandler):
|
||||||
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
|
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
|
||||||
)
|
)
|
||||||
|
|
||||||
displayname_to_set = new_displayname # type: Optional[str]
|
displayname_to_set: Optional[str] = new_displayname
|
||||||
if new_displayname == "":
|
if new_displayname == "":
|
||||||
displayname_to_set = None
|
displayname_to_set = None
|
||||||
|
|
||||||
|
@ -286,7 +286,7 @@ class ProfileHandler(BaseHandler):
|
||||||
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
|
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
|
||||||
)
|
)
|
||||||
|
|
||||||
avatar_url_to_set = new_avatar_url # type: Optional[str]
|
avatar_url_to_set: Optional[str] = new_avatar_url
|
||||||
if new_avatar_url == "":
|
if new_avatar_url == "":
|
||||||
avatar_url_to_set = None
|
avatar_url_to_set = None
|
||||||
|
|
||||||
|
|
|
@ -98,8 +98,8 @@ class ReceiptsHandler(BaseHandler):
|
||||||
|
|
||||||
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
|
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
|
||||||
"""Takes a list of receipts, stores them and informs the notifier."""
|
"""Takes a list of receipts, stores them and informs the notifier."""
|
||||||
min_batch_id = None # type: Optional[int]
|
min_batch_id: Optional[int] = None
|
||||||
max_batch_id = None # type: Optional[int]
|
max_batch_id: Optional[int] = None
|
||||||
|
|
||||||
for receipt in receipts:
|
for receipt in receipts:
|
||||||
res = await self.store.insert_receipt(
|
res = await self.store.insert_receipt(
|
||||||
|
|
|
@ -87,7 +87,7 @@ class RoomCreationHandler(BaseHandler):
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
|
|
||||||
# Room state based off defined presets
|
# Room state based off defined presets
|
||||||
self._presets_dict = {
|
self._presets_dict: Dict[str, Dict[str, Any]] = {
|
||||||
RoomCreationPreset.PRIVATE_CHAT: {
|
RoomCreationPreset.PRIVATE_CHAT: {
|
||||||
"join_rules": JoinRules.INVITE,
|
"join_rules": JoinRules.INVITE,
|
||||||
"history_visibility": HistoryVisibility.SHARED,
|
"history_visibility": HistoryVisibility.SHARED,
|
||||||
|
@ -109,7 +109,7 @@ class RoomCreationHandler(BaseHandler):
|
||||||
"guest_can_join": False,
|
"guest_can_join": False,
|
||||||
"power_level_content_override": {},
|
"power_level_content_override": {},
|
||||||
},
|
},
|
||||||
} # type: Dict[str, Dict[str, Any]]
|
}
|
||||||
|
|
||||||
# Modify presets to selectively enable encryption by default per homeserver config
|
# Modify presets to selectively enable encryption by default per homeserver config
|
||||||
for preset_name, preset_config in self._presets_dict.items():
|
for preset_name, preset_config in self._presets_dict.items():
|
||||||
|
@ -127,9 +127,9 @@ class RoomCreationHandler(BaseHandler):
|
||||||
# If a user tries to update the same room multiple times in quick
|
# If a user tries to update the same room multiple times in quick
|
||||||
# succession, only process the first attempt and return its result to
|
# succession, only process the first attempt and return its result to
|
||||||
# subsequent requests
|
# subsequent requests
|
||||||
self._upgrade_response_cache = ResponseCache(
|
self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
|
||||||
hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
|
hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
|
||||||
) # type: ResponseCache[Tuple[str, str]]
|
)
|
||||||
self._server_notices_mxid = hs.config.server_notices_mxid
|
self._server_notices_mxid = hs.config.server_notices_mxid
|
||||||
|
|
||||||
self.third_party_event_rules = hs.get_third_party_event_rules()
|
self.third_party_event_rules = hs.get_third_party_event_rules()
|
||||||
|
@ -377,10 +377,10 @@ class RoomCreationHandler(BaseHandler):
|
||||||
if not await self.spam_checker.user_may_create_room(user_id):
|
if not await self.spam_checker.user_may_create_room(user_id):
|
||||||
raise SynapseError(403, "You are not permitted to create rooms")
|
raise SynapseError(403, "You are not permitted to create rooms")
|
||||||
|
|
||||||
creation_content = {
|
creation_content: JsonDict = {
|
||||||
"room_version": new_room_version.identifier,
|
"room_version": new_room_version.identifier,
|
||||||
"predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
|
"predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
|
||||||
} # type: JsonDict
|
}
|
||||||
|
|
||||||
# Check if old room was non-federatable
|
# Check if old room was non-federatable
|
||||||
|
|
||||||
|
@ -936,7 +936,7 @@ class RoomCreationHandler(BaseHandler):
|
||||||
etype=EventTypes.PowerLevels, content=pl_content
|
etype=EventTypes.PowerLevels, content=pl_content
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
power_level_content = {
|
power_level_content: JsonDict = {
|
||||||
"users": {creator_id: 100},
|
"users": {creator_id: 100},
|
||||||
"users_default": 0,
|
"users_default": 0,
|
||||||
"events": {
|
"events": {
|
||||||
|
@ -955,7 +955,7 @@ class RoomCreationHandler(BaseHandler):
|
||||||
"kick": 50,
|
"kick": 50,
|
||||||
"redact": 50,
|
"redact": 50,
|
||||||
"invite": 50,
|
"invite": 50,
|
||||||
} # type: JsonDict
|
}
|
||||||
|
|
||||||
if config["original_invitees_have_ops"]:
|
if config["original_invitees_have_ops"]:
|
||||||
for invitee in invite_list:
|
for invitee in invite_list:
|
||||||
|
|
|
@ -47,12 +47,12 @@ class RoomListHandler(BaseHandler):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self.enable_room_list_search = hs.config.enable_room_list_search
|
self.enable_room_list_search = hs.config.enable_room_list_search
|
||||||
self.response_cache = ResponseCache(
|
self.response_cache: ResponseCache[
|
||||||
hs.get_clock(), "room_list"
|
Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]
|
||||||
) # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]]
|
] = ResponseCache(hs.get_clock(), "room_list")
|
||||||
self.remote_response_cache = ResponseCache(
|
self.remote_response_cache: ResponseCache[
|
||||||
hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
|
Tuple[str, Optional[int], Optional[str], bool, Optional[str]]
|
||||||
) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
|
] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000)
|
||||||
|
|
||||||
async def get_local_public_room_list(
|
async def get_local_public_room_list(
|
||||||
self,
|
self,
|
||||||
|
@ -139,10 +139,10 @@ class RoomListHandler(BaseHandler):
|
||||||
if since_token:
|
if since_token:
|
||||||
batch_token = RoomListNextBatch.from_token(since_token)
|
batch_token = RoomListNextBatch.from_token(since_token)
|
||||||
|
|
||||||
bounds = (
|
bounds: Optional[Tuple[int, str]] = (
|
||||||
batch_token.last_joined_members,
|
batch_token.last_joined_members,
|
||||||
batch_token.last_room_id,
|
batch_token.last_room_id,
|
||||||
) # type: Optional[Tuple[int, str]]
|
)
|
||||||
forwards = batch_token.direction_is_forward
|
forwards = batch_token.direction_is_forward
|
||||||
has_batch_token = True
|
has_batch_token = True
|
||||||
else:
|
else:
|
||||||
|
@ -182,7 +182,7 @@ class RoomListHandler(BaseHandler):
|
||||||
|
|
||||||
results = [build_room_entry(r) for r in results]
|
results = [build_room_entry(r) for r in results]
|
||||||
|
|
||||||
response = {} # type: JsonDict
|
response: JsonDict = {}
|
||||||
num_results = len(results)
|
num_results = len(results)
|
||||||
if limit is not None:
|
if limit is not None:
|
||||||
more_to_come = num_results == probing_limit
|
more_to_come = num_results == probing_limit
|
||||||
|
|
|
@ -83,7 +83,7 @@ class SamlHandler(BaseHandler):
|
||||||
self.unstable_idp_brand = None
|
self.unstable_idp_brand = None
|
||||||
|
|
||||||
# a map from saml session id to Saml2SessionData object
|
# a map from saml session id to Saml2SessionData object
|
||||||
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
|
self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {}
|
||||||
|
|
||||||
self._sso_handler = hs.get_sso_handler()
|
self._sso_handler = hs.get_sso_handler()
|
||||||
self._sso_handler.register_identity_provider(self)
|
self._sso_handler.register_identity_provider(self)
|
||||||
|
@ -386,10 +386,10 @@ def dot_replace_for_mxid(username: str) -> str:
|
||||||
return username
|
return username
|
||||||
|
|
||||||
|
|
||||||
MXID_MAPPER_MAP = {
|
MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = {
|
||||||
"hexencode": map_username_to_mxid_localpart,
|
"hexencode": map_username_to_mxid_localpart,
|
||||||
"dotreplace": dot_replace_for_mxid,
|
"dotreplace": dot_replace_for_mxid,
|
||||||
} # type: Dict[str, Callable[[str], str]]
|
}
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
|
|
|
@ -192,7 +192,7 @@ class SearchHandler(BaseHandler):
|
||||||
# If doing a subset of all rooms seearch, check if any of the rooms
|
# If doing a subset of all rooms seearch, check if any of the rooms
|
||||||
# are from an upgraded room, and search their contents as well
|
# are from an upgraded room, and search their contents as well
|
||||||
if search_filter.rooms:
|
if search_filter.rooms:
|
||||||
historical_room_ids = [] # type: List[str]
|
historical_room_ids: List[str] = []
|
||||||
for room_id in search_filter.rooms:
|
for room_id in search_filter.rooms:
|
||||||
# Add any previous rooms to the search if they exist
|
# Add any previous rooms to the search if they exist
|
||||||
ids = await self.get_old_rooms_from_upgraded_room(room_id)
|
ids = await self.get_old_rooms_from_upgraded_room(room_id)
|
||||||
|
@ -216,9 +216,9 @@ class SearchHandler(BaseHandler):
|
||||||
rank_map = {} # event_id -> rank of event
|
rank_map = {} # event_id -> rank of event
|
||||||
allowed_events = []
|
allowed_events = []
|
||||||
# Holds result of grouping by room, if applicable
|
# Holds result of grouping by room, if applicable
|
||||||
room_groups = {} # type: Dict[str, JsonDict]
|
room_groups: Dict[str, JsonDict] = {}
|
||||||
# Holds result of grouping by sender, if applicable
|
# Holds result of grouping by sender, if applicable
|
||||||
sender_group = {} # type: Dict[str, JsonDict]
|
sender_group: Dict[str, JsonDict] = {}
|
||||||
|
|
||||||
# Holds the next_batch for the entire result set if one of those exists
|
# Holds the next_batch for the entire result set if one of those exists
|
||||||
global_next_batch = None
|
global_next_batch = None
|
||||||
|
@ -262,7 +262,7 @@ class SearchHandler(BaseHandler):
|
||||||
s["results"].append(e.event_id)
|
s["results"].append(e.event_id)
|
||||||
|
|
||||||
elif order_by == "recent":
|
elif order_by == "recent":
|
||||||
room_events = [] # type: List[EventBase]
|
room_events: List[EventBase] = []
|
||||||
i = 0
|
i = 0
|
||||||
|
|
||||||
pagination_token = batch_token
|
pagination_token = batch_token
|
||||||
|
|
|
@ -90,14 +90,14 @@ class SpaceSummaryHandler:
|
||||||
room_queue = deque((_RoomQueueEntry(room_id, ()),))
|
room_queue = deque((_RoomQueueEntry(room_id, ()),))
|
||||||
|
|
||||||
# rooms we have already processed
|
# rooms we have already processed
|
||||||
processed_rooms = set() # type: Set[str]
|
processed_rooms: Set[str] = set()
|
||||||
|
|
||||||
# events we have already processed. We don't necessarily have their event ids,
|
# events we have already processed. We don't necessarily have their event ids,
|
||||||
# so instead we key on (room id, state key)
|
# so instead we key on (room id, state key)
|
||||||
processed_events = set() # type: Set[Tuple[str, str]]
|
processed_events: Set[Tuple[str, str]] = set()
|
||||||
|
|
||||||
rooms_result = [] # type: List[JsonDict]
|
rooms_result: List[JsonDict] = []
|
||||||
events_result = [] # type: List[JsonDict]
|
events_result: List[JsonDict] = []
|
||||||
|
|
||||||
while room_queue and len(rooms_result) < MAX_ROOMS:
|
while room_queue and len(rooms_result) < MAX_ROOMS:
|
||||||
queue_entry = room_queue.popleft()
|
queue_entry = room_queue.popleft()
|
||||||
|
@ -272,10 +272,10 @@ class SpaceSummaryHandler:
|
||||||
# the set of rooms that we should not walk further. Initialise it with the
|
# the set of rooms that we should not walk further. Initialise it with the
|
||||||
# excluded-rooms list; we will add other rooms as we process them so that
|
# excluded-rooms list; we will add other rooms as we process them so that
|
||||||
# we do not loop.
|
# we do not loop.
|
||||||
processed_rooms = set(exclude_rooms) # type: Set[str]
|
processed_rooms: Set[str] = set(exclude_rooms)
|
||||||
|
|
||||||
rooms_result = [] # type: List[JsonDict]
|
rooms_result: List[JsonDict] = []
|
||||||
events_result = [] # type: List[JsonDict]
|
events_result: List[JsonDict] = []
|
||||||
|
|
||||||
while room_queue and len(rooms_result) < MAX_ROOMS:
|
while room_queue and len(rooms_result) < MAX_ROOMS:
|
||||||
room_id = room_queue.popleft()
|
room_id = room_queue.popleft()
|
||||||
|
@ -353,7 +353,7 @@ class SpaceSummaryHandler:
|
||||||
max_children = MAX_ROOMS_PER_SPACE
|
max_children = MAX_ROOMS_PER_SPACE
|
||||||
|
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
events_result = [] # type: List[JsonDict]
|
events_result: List[JsonDict] = []
|
||||||
for edge_event in itertools.islice(child_events, max_children):
|
for edge_event in itertools.islice(child_events, max_children):
|
||||||
events_result.append(
|
events_result.append(
|
||||||
await self._event_serializer.serialize_event(
|
await self._event_serializer.serialize_event(
|
||||||
|
|
|
@ -202,10 +202,10 @@ class SsoHandler:
|
||||||
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
|
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
|
||||||
|
|
||||||
# a map from session id to session data
|
# a map from session id to session data
|
||||||
self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
|
self._username_mapping_sessions: Dict[str, UsernameMappingSession] = {}
|
||||||
|
|
||||||
# map from idp_id to SsoIdentityProvider
|
# map from idp_id to SsoIdentityProvider
|
||||||
self._identity_providers = {} # type: Dict[str, SsoIdentityProvider]
|
self._identity_providers: Dict[str, SsoIdentityProvider] = {}
|
||||||
|
|
||||||
self._consent_at_registration = hs.config.consent.user_consent_at_registration
|
self._consent_at_registration = hs.config.consent.user_consent_at_registration
|
||||||
|
|
||||||
|
@ -296,7 +296,7 @@ class SsoHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
# if the client chose an IdP, use that
|
# if the client chose an IdP, use that
|
||||||
idp = None # type: Optional[SsoIdentityProvider]
|
idp: Optional[SsoIdentityProvider] = None
|
||||||
if idp_id:
|
if idp_id:
|
||||||
idp = self._identity_providers.get(idp_id)
|
idp = self._identity_providers.get(idp_id)
|
||||||
if not idp:
|
if not idp:
|
||||||
|
@ -669,9 +669,9 @@ class SsoHandler:
|
||||||
remote_user_id,
|
remote_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
user_id_to_verify = await self._auth_handler.get_session_data(
|
user_id_to_verify: str = await self._auth_handler.get_session_data(
|
||||||
ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
|
ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
|
||||||
) # type: str
|
)
|
||||||
|
|
||||||
if not user_id:
|
if not user_id:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
@ -793,7 +793,7 @@ class SsoHandler:
|
||||||
session.use_display_name = use_display_name
|
session.use_display_name = use_display_name
|
||||||
|
|
||||||
emails_from_idp = set(session.emails)
|
emails_from_idp = set(session.emails)
|
||||||
filtered_emails = set() # type: Set[str]
|
filtered_emails: Set[str] = set()
|
||||||
|
|
||||||
# we iterate through the list rather than just building a set conjunction, so
|
# we iterate through the list rather than just building a set conjunction, so
|
||||||
# that we can log attempts to use unknown addresses
|
# that we can log attempts to use unknown addresses
|
||||||
|
|
|
@ -49,7 +49,7 @@ class StatsHandler:
|
||||||
self.stats_enabled = hs.config.stats_enabled
|
self.stats_enabled = hs.config.stats_enabled
|
||||||
|
|
||||||
# The current position in the current_state_delta stream
|
# The current position in the current_state_delta stream
|
||||||
self.pos = None # type: Optional[int]
|
self.pos: Optional[int] = None
|
||||||
|
|
||||||
# Guard to ensure we only process deltas one at a time
|
# Guard to ensure we only process deltas one at a time
|
||||||
self._is_processing = False
|
self._is_processing = False
|
||||||
|
@ -131,10 +131,10 @@ class StatsHandler:
|
||||||
mapping from room/user ID to changes in the various fields.
|
mapping from room/user ID to changes in the various fields.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
room_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
|
room_to_stats_deltas: Dict[str, CounterType[str]] = {}
|
||||||
user_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
|
user_to_stats_deltas: Dict[str, CounterType[str]] = {}
|
||||||
|
|
||||||
room_to_state_updates = {} # type: Dict[str, Dict[str, Any]]
|
room_to_state_updates: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
for delta in deltas:
|
for delta in deltas:
|
||||||
typ = delta["type"]
|
typ = delta["type"]
|
||||||
|
@ -164,7 +164,7 @@ class StatsHandler:
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
event_content = {} # type: JsonDict
|
event_content: JsonDict = {}
|
||||||
|
|
||||||
if event_id is not None:
|
if event_id is not None:
|
||||||
event = await self.store.get_event(event_id, allow_none=True)
|
event = await self.store.get_event(event_id, allow_none=True)
|
||||||
|
|
|
@ -278,12 +278,14 @@ class SyncHandler:
|
||||||
self.state_store = self.storage.state
|
self.state_store = self.storage.state
|
||||||
|
|
||||||
# ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
|
# ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
|
||||||
self.lazy_loaded_members_cache = ExpiringCache(
|
self.lazy_loaded_members_cache: ExpiringCache[
|
||||||
|
Tuple[str, Optional[str]], LruCache[str, str]
|
||||||
|
] = ExpiringCache(
|
||||||
"lazy_loaded_members_cache",
|
"lazy_loaded_members_cache",
|
||||||
self.clock,
|
self.clock,
|
||||||
max_len=0,
|
max_len=0,
|
||||||
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
|
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
|
||||||
) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
|
)
|
||||||
|
|
||||||
async def wait_for_sync_for_user(
|
async def wait_for_sync_for_user(
|
||||||
self,
|
self,
|
||||||
|
@ -440,7 +442,7 @@ class SyncHandler:
|
||||||
)
|
)
|
||||||
now_token = now_token.copy_and_replace("typing_key", typing_key)
|
now_token = now_token.copy_and_replace("typing_key", typing_key)
|
||||||
|
|
||||||
ephemeral_by_room = {} # type: JsonDict
|
ephemeral_by_room: JsonDict = {}
|
||||||
|
|
||||||
for event in typing:
|
for event in typing:
|
||||||
# we want to exclude the room_id from the event, but modifying the
|
# we want to exclude the room_id from the event, but modifying the
|
||||||
|
@ -502,7 +504,7 @@ class SyncHandler:
|
||||||
# We check if there are any state events, if there are then we pass
|
# We check if there are any state events, if there are then we pass
|
||||||
# all current state events to the filter_events function. This is to
|
# all current state events to the filter_events function. This is to
|
||||||
# ensure that we always include current state in the timeline
|
# ensure that we always include current state in the timeline
|
||||||
current_state_ids = frozenset() # type: FrozenSet[str]
|
current_state_ids: FrozenSet[str] = frozenset()
|
||||||
if any(e.is_state() for e in recents):
|
if any(e.is_state() for e in recents):
|
||||||
current_state_ids_map = await self.store.get_current_state_ids(
|
current_state_ids_map = await self.store.get_current_state_ids(
|
||||||
room_id
|
room_id
|
||||||
|
@ -783,9 +785,9 @@ class SyncHandler:
|
||||||
def get_lazy_loaded_members_cache(
|
def get_lazy_loaded_members_cache(
|
||||||
self, cache_key: Tuple[str, Optional[str]]
|
self, cache_key: Tuple[str, Optional[str]]
|
||||||
) -> LruCache[str, str]:
|
) -> LruCache[str, str]:
|
||||||
cache = self.lazy_loaded_members_cache.get(
|
cache: Optional[LruCache[str, str]] = self.lazy_loaded_members_cache.get(
|
||||||
cache_key
|
cache_key
|
||||||
) # type: Optional[LruCache[str, str]]
|
)
|
||||||
if cache is None:
|
if cache is None:
|
||||||
logger.debug("creating LruCache for %r", cache_key)
|
logger.debug("creating LruCache for %r", cache_key)
|
||||||
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
|
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
|
||||||
|
@ -984,7 +986,7 @@ class SyncHandler:
|
||||||
if t[0] == EventTypes.Member:
|
if t[0] == EventTypes.Member:
|
||||||
cache.set(t[1], event_id)
|
cache.set(t[1], event_id)
|
||||||
|
|
||||||
state = {} # type: Dict[str, EventBase]
|
state: Dict[str, EventBase] = {}
|
||||||
if state_ids:
|
if state_ids:
|
||||||
state = await self.store.get_events(list(state_ids.values()))
|
state = await self.store.get_events(list(state_ids.values()))
|
||||||
|
|
||||||
|
@ -1088,8 +1090,8 @@ class SyncHandler:
|
||||||
|
|
||||||
logger.debug("Fetching OTK data")
|
logger.debug("Fetching OTK data")
|
||||||
device_id = sync_config.device_id
|
device_id = sync_config.device_id
|
||||||
one_time_key_counts = {} # type: JsonDict
|
one_time_key_counts: JsonDict = {}
|
||||||
unused_fallback_key_types = [] # type: List[str]
|
unused_fallback_key_types: List[str] = []
|
||||||
if device_id:
|
if device_id:
|
||||||
one_time_key_counts = await self.store.count_e2e_one_time_keys(
|
one_time_key_counts = await self.store.count_e2e_one_time_keys(
|
||||||
user_id, device_id
|
user_id, device_id
|
||||||
|
@ -1437,7 +1439,7 @@ class SyncHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
if block_all_room_ephemeral:
|
if block_all_room_ephemeral:
|
||||||
ephemeral_by_room = {} # type: Dict[str, List[JsonDict]]
|
ephemeral_by_room: Dict[str, List[JsonDict]] = {}
|
||||||
else:
|
else:
|
||||||
now_token, ephemeral_by_room = await self.ephemeral_by_room(
|
now_token, ephemeral_by_room = await self.ephemeral_by_room(
|
||||||
sync_result_builder,
|
sync_result_builder,
|
||||||
|
@ -1468,7 +1470,7 @@ class SyncHandler:
|
||||||
|
|
||||||
# If there is ignored users account data and it matches the proper type,
|
# If there is ignored users account data and it matches the proper type,
|
||||||
# then use it.
|
# then use it.
|
||||||
ignored_users = frozenset() # type: FrozenSet[str]
|
ignored_users: FrozenSet[str] = frozenset()
|
||||||
if ignored_account_data:
|
if ignored_account_data:
|
||||||
ignored_users_data = ignored_account_data.get("ignored_users", {})
|
ignored_users_data = ignored_account_data.get("ignored_users", {})
|
||||||
if isinstance(ignored_users_data, dict):
|
if isinstance(ignored_users_data, dict):
|
||||||
|
@ -1586,7 +1588,7 @@ class SyncHandler:
|
||||||
user_id, since_token.room_key, now_token.room_key
|
user_id, since_token.room_key, now_token.room_key
|
||||||
)
|
)
|
||||||
|
|
||||||
mem_change_events_by_room_id = {} # type: Dict[str, List[EventBase]]
|
mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
|
||||||
for event in rooms_changed:
|
for event in rooms_changed:
|
||||||
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
|
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
|
||||||
|
|
||||||
|
@ -1722,7 +1724,7 @@ class SyncHandler:
|
||||||
# This is all screaming out for a refactor, as the logic here is
|
# This is all screaming out for a refactor, as the logic here is
|
||||||
# subtle and the moving parts numerous.
|
# subtle and the moving parts numerous.
|
||||||
if leave_event.internal_metadata.is_out_of_band_membership():
|
if leave_event.internal_metadata.is_out_of_band_membership():
|
||||||
batch_events = [leave_event] # type: Optional[List[EventBase]]
|
batch_events: Optional[List[EventBase]] = [leave_event]
|
||||||
else:
|
else:
|
||||||
batch_events = None
|
batch_events = None
|
||||||
|
|
||||||
|
@ -1971,7 +1973,7 @@ class SyncHandler:
|
||||||
room_id, batch, sync_config, since_token, now_token, full_state=full_state
|
room_id, batch, sync_config, since_token, now_token, full_state=full_state
|
||||||
)
|
)
|
||||||
|
|
||||||
summary = {} # type: Optional[JsonDict]
|
summary: Optional[JsonDict] = {}
|
||||||
|
|
||||||
# we include a summary in room responses when we're lazy loading
|
# we include a summary in room responses when we're lazy loading
|
||||||
# members (as the client otherwise doesn't have enough info to form
|
# members (as the client otherwise doesn't have enough info to form
|
||||||
|
@ -1995,7 +1997,7 @@ class SyncHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
if room_builder.rtype == "joined":
|
if room_builder.rtype == "joined":
|
||||||
unread_notifications = {} # type: Dict[str, int]
|
unread_notifications: Dict[str, int] = {}
|
||||||
room_sync = JoinedSyncResult(
|
room_sync = JoinedSyncResult(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
timeline=batch,
|
timeline=batch,
|
||||||
|
|
|
@ -68,11 +68,11 @@ class FollowerTypingHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
# map room IDs to serial numbers
|
# map room IDs to serial numbers
|
||||||
self._room_serials = {} # type: Dict[str, int]
|
self._room_serials: Dict[str, int] = {}
|
||||||
# map room IDs to sets of users currently typing
|
# map room IDs to sets of users currently typing
|
||||||
self._room_typing = {} # type: Dict[str, Set[str]]
|
self._room_typing: Dict[str, Set[str]] = {}
|
||||||
|
|
||||||
self._member_last_federation_poke = {} # type: Dict[RoomMember, int]
|
self._member_last_federation_poke: Dict[RoomMember, int] = {}
|
||||||
self.wheel_timer = WheelTimer(bucket_size=5000)
|
self.wheel_timer = WheelTimer(bucket_size=5000)
|
||||||
self._latest_room_serial = 0
|
self._latest_room_serial = 0
|
||||||
|
|
||||||
|
@ -217,7 +217,7 @@ class TypingWriterHandler(FollowerTypingHandler):
|
||||||
hs.get_distributor().observe("user_left_room", self.user_left_room)
|
hs.get_distributor().observe("user_left_room", self.user_left_room)
|
||||||
|
|
||||||
# clock time we expect to stop
|
# clock time we expect to stop
|
||||||
self._member_typing_until = {} # type: Dict[RoomMember, int]
|
self._member_typing_until: Dict[RoomMember, int] = {}
|
||||||
|
|
||||||
# caches which room_ids changed at which serials
|
# caches which room_ids changed at which serials
|
||||||
self._typing_stream_change_cache = StreamChangeCache(
|
self._typing_stream_change_cache = StreamChangeCache(
|
||||||
|
@ -405,9 +405,9 @@ class TypingWriterHandler(FollowerTypingHandler):
|
||||||
if last_id == current_id:
|
if last_id == current_id:
|
||||||
return [], current_id, False
|
return [], current_id, False
|
||||||
|
|
||||||
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
|
changed_rooms: Optional[
|
||||||
last_id
|
Iterable[str]
|
||||||
) # type: Optional[Iterable[str]]
|
] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
|
||||||
|
|
||||||
if changed_rooms is None:
|
if changed_rooms is None:
|
||||||
changed_rooms = self._room_serials
|
changed_rooms = self._room_serials
|
||||||
|
|
|
@ -52,7 +52,7 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||||
self.search_all_users = hs.config.user_directory_search_all_users
|
self.search_all_users = hs.config.user_directory_search_all_users
|
||||||
self.spam_checker = hs.get_spam_checker()
|
self.spam_checker = hs.get_spam_checker()
|
||||||
# The current position in the current_state_delta stream
|
# The current position in the current_state_delta stream
|
||||||
self.pos = None # type: Optional[int]
|
self.pos: Optional[int] = None
|
||||||
|
|
||||||
# Guard to ensure we only process deltas one at a time
|
# Guard to ensure we only process deltas one at a time
|
||||||
self._is_processing = False
|
self._is_processing = False
|
||||||
|
|
|
@ -402,9 +402,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
|
||||||
|
|
||||||
# Get the room ID from the identifier.
|
# Get the room ID from the identifier.
|
||||||
try:
|
try:
|
||||||
remote_room_hosts = [
|
remote_room_hosts: Optional[List[str]] = [
|
||||||
x.decode("ascii") for x in request.args[b"server_name"]
|
x.decode("ascii") for x in request.args[b"server_name"]
|
||||||
] # type: Optional[List[str]]
|
]
|
||||||
except Exception:
|
except Exception:
|
||||||
remote_room_hosts = None
|
remote_room_hosts = None
|
||||||
room_id, remote_room_hosts = await self.resolve_room_id(
|
room_id, remote_room_hosts = await self.resolve_room_id(
|
||||||
|
@ -659,9 +659,7 @@ class RoomEventContextServlet(RestServlet):
|
||||||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
filter_str = parse_string(request, "filter", encoding="utf-8")
|
||||||
if filter_str:
|
if filter_str:
|
||||||
filter_json = urlparse.unquote(filter_str)
|
filter_json = urlparse.unquote(filter_str)
|
||||||
event_filter = Filter(
|
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
|
||||||
json_decoder.decode(filter_json)
|
|
||||||
) # type: Optional[Filter]
|
|
||||||
else:
|
else:
|
||||||
event_filter = None
|
event_filter = None
|
||||||
|
|
||||||
|
|
|
@ -357,7 +357,7 @@ class UserRegisterServlet(RestServlet):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self.reactor = hs.get_reactor()
|
self.reactor = hs.get_reactor()
|
||||||
self.nonces = {} # type: Dict[str, int]
|
self.nonces: Dict[str, int] = {}
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
def _clear_old_nonces(self):
|
def _clear_old_nonces(self):
|
||||||
|
|
|
@ -121,7 +121,7 @@ class LoginRestServlet(RestServlet):
|
||||||
flows.append({"type": LoginRestServlet.CAS_TYPE})
|
flows.append({"type": LoginRestServlet.CAS_TYPE})
|
||||||
|
|
||||||
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
|
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
|
||||||
sso_flow = {
|
sso_flow: JsonDict = {
|
||||||
"type": LoginRestServlet.SSO_TYPE,
|
"type": LoginRestServlet.SSO_TYPE,
|
||||||
"identity_providers": [
|
"identity_providers": [
|
||||||
_get_auth_flow_dict_for_idp(
|
_get_auth_flow_dict_for_idp(
|
||||||
|
@ -129,7 +129,7 @@ class LoginRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
for idp in self._sso_handler.get_identity_providers().values()
|
for idp in self._sso_handler.get_identity_providers().values()
|
||||||
],
|
],
|
||||||
} # type: JsonDict
|
}
|
||||||
|
|
||||||
if self._msc2858_enabled:
|
if self._msc2858_enabled:
|
||||||
# backwards-compatibility support for clients which don't
|
# backwards-compatibility support for clients which don't
|
||||||
|
@ -447,7 +447,7 @@ def _get_auth_flow_dict_for_idp(
|
||||||
use_unstable_brands: whether we should use brand identifiers suitable
|
use_unstable_brands: whether we should use brand identifiers suitable
|
||||||
for the unstable API
|
for the unstable API
|
||||||
"""
|
"""
|
||||||
e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
|
e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name}
|
||||||
if idp.idp_icon:
|
if idp.idp_icon:
|
||||||
e["icon"] = idp.idp_icon
|
e["icon"] = idp.idp_icon
|
||||||
if idp.idp_brand:
|
if idp.idp_brand:
|
||||||
|
@ -561,7 +561,7 @@ class SsoRedirectServlet(RestServlet):
|
||||||
finish_request(request)
|
finish_request(request)
|
||||||
return
|
return
|
||||||
|
|
||||||
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
|
args: Dict[bytes, List[bytes]] = request.args # type: ignore
|
||||||
client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True)
|
client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True)
|
||||||
sso_url = await self._sso_handler.handle_redirect_request(
|
sso_url = await self._sso_handler.handle_redirect_request(
|
||||||
request,
|
request,
|
||||||
|
|
|
@ -783,7 +783,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
|
||||||
server = parse_string(request, "server", default=None)
|
server = parse_string(request, "server", default=None)
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
limit = int(content.get("limit", 100)) # type: Optional[int]
|
limit: Optional[int] = int(content.get("limit", 100))
|
||||||
since_token = content.get("since", None)
|
since_token = content.get("since", None)
|
||||||
search_filter = content.get("filter", None)
|
search_filter = content.get("filter", None)
|
||||||
|
|
||||||
|
@ -929,9 +929,7 @@ class RoomMessageListRestServlet(RestServlet):
|
||||||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
filter_str = parse_string(request, "filter", encoding="utf-8")
|
||||||
if filter_str:
|
if filter_str:
|
||||||
filter_json = urlparse.unquote(filter_str)
|
filter_json = urlparse.unquote(filter_str)
|
||||||
event_filter = Filter(
|
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
|
||||||
json_decoder.decode(filter_json)
|
|
||||||
) # type: Optional[Filter]
|
|
||||||
if (
|
if (
|
||||||
event_filter
|
event_filter
|
||||||
and event_filter.filter_json.get("event_format", "client")
|
and event_filter.filter_json.get("event_format", "client")
|
||||||
|
@ -1044,9 +1042,7 @@ class RoomEventContextServlet(RestServlet):
|
||||||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
filter_str = parse_string(request, "filter", encoding="utf-8")
|
||||||
if filter_str:
|
if filter_str:
|
||||||
filter_json = urlparse.unquote(filter_str)
|
filter_json = urlparse.unquote(filter_str)
|
||||||
event_filter = Filter(
|
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
|
||||||
json_decoder.decode(filter_json)
|
|
||||||
) # type: Optional[Filter]
|
|
||||||
else:
|
else:
|
||||||
event_filter = None
|
event_filter = None
|
||||||
|
|
||||||
|
|
|
@ -59,7 +59,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
|
||||||
requester, message_type, content["messages"]
|
requester, message_type, content["messages"]
|
||||||
)
|
)
|
||||||
|
|
||||||
response = (200, {}) # type: Tuple[int, dict]
|
response: Tuple[int, dict] = (200, {})
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -117,7 +117,7 @@ class ConsentResource(DirectServeHtmlResource):
|
||||||
has_consented = False
|
has_consented = False
|
||||||
public_version = username == ""
|
public_version = username == ""
|
||||||
if not public_version:
|
if not public_version:
|
||||||
args = request.args # type: Dict[bytes, List[bytes]]
|
args: Dict[bytes, List[bytes]] = request.args
|
||||||
userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
|
userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
|
||||||
|
|
||||||
self._check_hash(username, userhmac_bytes)
|
self._check_hash(username, userhmac_bytes)
|
||||||
|
@ -154,7 +154,7 @@ class ConsentResource(DirectServeHtmlResource):
|
||||||
"""
|
"""
|
||||||
version = parse_string(request, "v", required=True)
|
version = parse_string(request, "v", required=True)
|
||||||
username = parse_string(request, "u", required=True)
|
username = parse_string(request, "u", required=True)
|
||||||
args = request.args # type: Dict[bytes, List[bytes]]
|
args: Dict[bytes, List[bytes]] = request.args
|
||||||
userhmac = parse_bytes_from_args(args, "h", required=True)
|
userhmac = parse_bytes_from_args(args, "h", required=True)
|
||||||
|
|
||||||
self._check_hash(username, userhmac)
|
self._check_hash(username, userhmac)
|
||||||
|
|
|
@ -97,7 +97,7 @@ class RemoteKey(DirectServeJsonResource):
|
||||||
async def _async_render_GET(self, request):
|
async def _async_render_GET(self, request):
|
||||||
if len(request.postpath) == 1:
|
if len(request.postpath) == 1:
|
||||||
(server,) = request.postpath
|
(server,) = request.postpath
|
||||||
query = {server.decode("ascii"): {}} # type: dict
|
query: dict = {server.decode("ascii"): {}}
|
||||||
elif len(request.postpath) == 2:
|
elif len(request.postpath) == 2:
|
||||||
server, key_id = request.postpath
|
server, key_id = request.postpath
|
||||||
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
|
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
|
||||||
|
@ -141,7 +141,7 @@ class RemoteKey(DirectServeJsonResource):
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
||||||
# Note that the value is unused.
|
# Note that the value is unused.
|
||||||
cache_misses = {} # type: Dict[str, Dict[str, int]]
|
cache_misses: Dict[str, Dict[str, int]] = {}
|
||||||
for (server_name, key_id, _), results in cached.items():
|
for (server_name, key_id, _), results in cached.items():
|
||||||
results = [(result["ts_added_ms"], result) for result in results]
|
results = [(result["ts_added_ms"], result) for result in results]
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ TEXT_CONTENT_TYPES = [
|
||||||
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
|
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
|
||||||
try:
|
try:
|
||||||
# The type on postpath seems incorrect in Twisted 21.2.0.
|
# The type on postpath seems incorrect in Twisted 21.2.0.
|
||||||
postpath = request.postpath # type: List[bytes] # type: ignore
|
postpath: List[bytes] = request.postpath # type: ignore
|
||||||
assert postpath
|
assert postpath
|
||||||
|
|
||||||
# This allows users to append e.g. /test.png to the URL. Useful for
|
# This allows users to append e.g. /test.png to the URL. Useful for
|
||||||
|
|
|
@ -78,16 +78,16 @@ class MediaRepository:
|
||||||
|
|
||||||
Thumbnailer.set_limits(self.max_image_pixels)
|
Thumbnailer.set_limits(self.max_image_pixels)
|
||||||
|
|
||||||
self.primary_base_path = hs.config.media_store_path # type: str
|
self.primary_base_path: str = hs.config.media_store_path
|
||||||
self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
|
self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path)
|
||||||
|
|
||||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||||
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
||||||
|
|
||||||
self.remote_media_linearizer = Linearizer(name="media_remote")
|
self.remote_media_linearizer = Linearizer(name="media_remote")
|
||||||
|
|
||||||
self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]]
|
self.recently_accessed_remotes: Set[Tuple[str, str]] = set()
|
||||||
self.recently_accessed_locals = set() # type: Set[str]
|
self.recently_accessed_locals: Set[str] = set()
|
||||||
|
|
||||||
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||||
|
|
||||||
|
@ -711,7 +711,7 @@ class MediaRepository:
|
||||||
|
|
||||||
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
|
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
|
||||||
# they have the same dimensions of a scaled one.
|
# they have the same dimensions of a scaled one.
|
||||||
thumbnails = {} # type: Dict[Tuple[int, int, str], str]
|
thumbnails: Dict[Tuple[int, int, str], str] = {}
|
||||||
for r_width, r_height, r_method, r_type in requirements:
|
for r_width, r_height, r_method, r_type in requirements:
|
||||||
if r_method == "crop":
|
if r_method == "crop":
|
||||||
thumbnails.setdefault((r_width, r_height, r_type), r_method)
|
thumbnails.setdefault((r_width, r_height, r_type), r_method)
|
||||||
|
|
|
@ -191,7 +191,7 @@ class MediaStorage:
|
||||||
|
|
||||||
for provider in self.storage_providers:
|
for provider in self.storage_providers:
|
||||||
for path in paths:
|
for path in paths:
|
||||||
res = await provider.fetch(path, file_info) # type: Any
|
res: Any = await provider.fetch(path, file_info)
|
||||||
if res:
|
if res:
|
||||||
logger.debug("Streaming %s from %s", path, provider)
|
logger.debug("Streaming %s from %s", path, provider)
|
||||||
return res
|
return res
|
||||||
|
@ -233,7 +233,7 @@ class MediaStorage:
|
||||||
os.makedirs(dirname)
|
os.makedirs(dirname)
|
||||||
|
|
||||||
for provider in self.storage_providers:
|
for provider in self.storage_providers:
|
||||||
res = await provider.fetch(path, file_info) # type: Any
|
res: Any = await provider.fetch(path, file_info)
|
||||||
if res:
|
if res:
|
||||||
with res:
|
with res:
|
||||||
consumer = BackgroundFileConsumer(
|
consumer = BackgroundFileConsumer(
|
||||||
|
|
|
@ -169,12 +169,12 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||||
|
|
||||||
# memory cache mapping urls to an ObservableDeferred returning
|
# memory cache mapping urls to an ObservableDeferred returning
|
||||||
# JSON-encoded OG metadata
|
# JSON-encoded OG metadata
|
||||||
self._cache = ExpiringCache(
|
self._cache: ExpiringCache[str, ObservableDeferred] = ExpiringCache(
|
||||||
cache_name="url_previews",
|
cache_name="url_previews",
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
# don't spider URLs more often than once an hour
|
# don't spider URLs more often than once an hour
|
||||||
expiry_ms=ONE_HOUR,
|
expiry_ms=ONE_HOUR,
|
||||||
) # type: ExpiringCache[str, ObservableDeferred]
|
)
|
||||||
|
|
||||||
if self._worker_run_media_background_jobs:
|
if self._worker_run_media_background_jobs:
|
||||||
self._cleaner_loop = self.clock.looping_call(
|
self._cleaner_loop = self.clock.looping_call(
|
||||||
|
@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||||
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
|
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
|
||||||
|
|
||||||
# If this URL can be accessed via oEmbed, use that instead.
|
# If this URL can be accessed via oEmbed, use that instead.
|
||||||
url_to_download = url # type: Optional[str]
|
url_to_download: Optional[str] = url
|
||||||
oembed_url = self._get_oembed_url(url)
|
oembed_url = self._get_oembed_url(url)
|
||||||
if oembed_url:
|
if oembed_url:
|
||||||
# The result might be a new URL to download, or it might be HTML content.
|
# The result might be a new URL to download, or it might be HTML content.
|
||||||
|
@ -788,7 +788,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
|
||||||
# "og:video:height" : "720",
|
# "og:video:height" : "720",
|
||||||
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
|
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
|
||||||
|
|
||||||
og = {} # type: Dict[str, Optional[str]]
|
og: Dict[str, Optional[str]] = {}
|
||||||
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
||||||
if "content" in tag.attrib:
|
if "content" in tag.attrib:
|
||||||
# if we've got more than 50 tags, someone is taking the piss
|
# if we've got more than 50 tags, someone is taking the piss
|
||||||
|
|
|
@ -61,11 +61,11 @@ class UploadResource(DirectServeJsonResource):
|
||||||
errcode=Codes.TOO_LARGE,
|
errcode=Codes.TOO_LARGE,
|
||||||
)
|
)
|
||||||
|
|
||||||
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
|
args: Dict[bytes, List[bytes]] = request.args # type: ignore
|
||||||
upload_name_bytes = parse_bytes_from_args(args, "filename")
|
upload_name_bytes = parse_bytes_from_args(args, "filename")
|
||||||
if upload_name_bytes:
|
if upload_name_bytes:
|
||||||
try:
|
try:
|
||||||
upload_name = upload_name_bytes.decode("utf8") # type: Optional[str]
|
upload_name: Optional[str] = upload_name_bytes.decode("utf8")
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400
|
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400
|
||||||
|
@ -89,7 +89,7 @@ class UploadResource(DirectServeJsonResource):
|
||||||
# TODO(markjh): parse content-dispostion
|
# TODO(markjh): parse content-dispostion
|
||||||
|
|
||||||
try:
|
try:
|
||||||
content = request.content # type: IO # type: ignore
|
content: IO = request.content # type: ignore
|
||||||
content_uri = await self.media_repo.create_content(
|
content_uri = await self.media_repo.create_content(
|
||||||
media_type, upload_name, content, content_length, requester.user
|
media_type, upload_name, content, content_length, requester.user
|
||||||
)
|
)
|
||||||
|
|
|
@ -118,9 +118,9 @@ class AccountDetailsResource(DirectServeHtmlResource):
|
||||||
use_display_name = parse_boolean(request, "use_display_name", default=False)
|
use_display_name = parse_boolean(request, "use_display_name", default=False)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
emails_to_use = [
|
emails_to_use: List[str] = [
|
||||||
val.decode("utf-8") for val in request.args.get(b"use_email", [])
|
val.decode("utf-8") for val in request.args.get(b"use_email", [])
|
||||||
] # type: List[str]
|
]
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise SynapseError(400, "Query parameter use_email must be utf-8")
|
raise SynapseError(400, "Query parameter use_email must be utf-8")
|
||||||
except SynapseError as e:
|
except SynapseError as e:
|
||||||
|
|
Loading…
Reference in a new issue