Use inline type hints in http/federation/, storage/ and util/ (#10381)

This commit is contained in:
Jonathan de Jong 2021-07-15 18:46:54 +02:00 committed by GitHub
parent 3acf85c85f
commit bdfde6dca1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
38 changed files with 149 additions and 161 deletions

1
changelog.d/10381.misc Normal file
View file

@ -0,0 +1 @@
Convert internal type variable syntax to reflect wider ecosystem use.

View file

@ -70,10 +70,8 @@ WELL_KNOWN_RETRY_ATTEMPTS = 3
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_well_known_cache = TTLCache("well-known") # type: TTLCache[bytes, Optional[bytes]] _well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
_had_valid_well_known_cache = TTLCache( _had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")
"had-valid-well-known"
) # type: TTLCache[bytes, bool]
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
@ -130,9 +128,10 @@ class WellKnownResolver:
# requests for the same server in parallel? # requests for the same server in parallel?
try: try:
with Measure(self._clock, "get_well_known"): with Measure(self._clock, "get_well_known"):
result, cache_period = await self._fetch_well_known( result: Optional[bytes]
server_name cache_period: float
) # type: Optional[bytes], float
result, cache_period = await self._fetch_well_known(server_name)
except _FetchWellKnownFailure as e: except _FetchWellKnownFailure as e:
if prev_result and e.temporary: if prev_result and e.temporary:

View file

@ -92,14 +92,12 @@ class BackgroundUpdater:
self.db_pool = database self.db_pool = database
# if a background update is currently running, its name. # if a background update is currently running, its name.
self._current_background_update = None # type: Optional[str] self._current_background_update: Optional[str] = None
self._background_update_performance = ( self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
{} self._background_update_handlers: Dict[
) # type: Dict[str, BackgroundUpdatePerformance] str, Callable[[JsonDict, int], Awaitable[int]]
self._background_update_handlers = ( ] = {}
{}
) # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
self._all_done = False self._all_done = False
def start_doing_background_updates(self) -> None: def start_doing_background_updates(self) -> None:
@ -411,7 +409,7 @@ class BackgroundUpdater:
c.execute(sql) c.execute(sql)
if isinstance(self.db_pool.engine, engines.PostgresEngine): if isinstance(self.db_pool.engine, engines.PostgresEngine):
runner = create_index_psql # type: Optional[Callable[[Connection], None]] runner: Optional[Callable[[Connection], None]] = create_index_psql
elif psql_only: elif psql_only:
runner = None runner = None
else: else:

View file

@ -670,8 +670,8 @@ class DatabasePool:
Returns: Returns:
The result of func The result of func
""" """
after_callbacks = [] # type: List[_CallbackListEntry] after_callbacks: List[_CallbackListEntry] = []
exception_callbacks = [] # type: List[_CallbackListEntry] exception_callbacks: List[_CallbackListEntry] = []
if not current_context(): if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc) logger.warning("Starting db txn '%s' from sentinel context", desc)
@ -1090,7 +1090,7 @@ class DatabasePool:
return False return False
# We didn't find any existing rows, so insert a new one # We didn't find any existing rows, so insert a new one
allvalues = {} # type: Dict[str, Any] allvalues: Dict[str, Any] = {}
allvalues.update(keyvalues) allvalues.update(keyvalues)
allvalues.update(values) allvalues.update(values)
allvalues.update(insertion_values) allvalues.update(insertion_values)
@ -1121,7 +1121,7 @@ class DatabasePool:
values: The nonunique columns and their new values values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting insertion_values: additional key/values to use only when inserting
""" """
allvalues = {} # type: Dict[str, Any] allvalues: Dict[str, Any] = {}
allvalues.update(keyvalues) allvalues.update(keyvalues)
allvalues.update(insertion_values or {}) allvalues.update(insertion_values or {})
@ -1257,7 +1257,7 @@ class DatabasePool:
value_values: A list of each row's value column values. value_values: A list of each row's value column values.
Ignored if value_names is empty. Ignored if value_names is empty.
""" """
allnames = [] # type: List[str] allnames: List[str] = []
allnames.extend(key_names) allnames.extend(key_names)
allnames.extend(value_names) allnames.extend(value_names)
@ -1566,7 +1566,7 @@ class DatabasePool:
""" """
keyvalues = keyvalues or {} keyvalues = keyvalues or {}
results = [] # type: List[Dict[str, Any]] results: List[Dict[str, Any]] = []
if not iterable: if not iterable:
return results return results
@ -1978,7 +1978,7 @@ class DatabasePool:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else "" where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else ""
arg_list = [] # type: List[Any] arg_list: List[Any] = []
if filters: if filters:
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters) where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
arg_list += list(filters.values()) arg_list += list(filters.values())

View file

@ -48,9 +48,7 @@ def _make_exclusive_regex(
] ]
if exclusive_user_regexes: if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
exclusive_user_pattern = re.compile( exclusive_user_pattern: Optional[Pattern] = re.compile(exclusive_user_regex)
exclusive_user_regex
) # type: Optional[Pattern]
else: else:
# We handle this case specially otherwise the constructed regex # We handle this case specially otherwise the constructed regex
# will always match # will always match

View file

@ -247,7 +247,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
txn.execute(sql, query_params) txn.execute(sql, query_params)
result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
for (user_id, device_id, display_name, key_json) in txn: for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices: if include_deleted_devices:
deleted_devices.remove((user_id, device_id)) deleted_devices.remove((user_id, device_id))

View file

@ -62,9 +62,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
) )
# Cache of event ID to list of auth event IDs and their depths. # Cache of event ID to list of auth event IDs and their depths.
self._event_auth_cache = LruCache( self._event_auth_cache: LruCache[str, List[Tuple[str, int]]] = LruCache(
500000, "_event_auth_cache", size_callback=len 500000, "_event_auth_cache", size_callback=len
) # type: LruCache[str, List[Tuple[str, int]]] )
self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000) self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
@ -137,10 +137,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
initial_events = set(event_ids) initial_events = set(event_ids)
# All the events that we've found that are reachable from the events. # All the events that we've found that are reachable from the events.
seen_events = set() # type: Set[str] seen_events: Set[str] = set()
# A map from chain ID to max sequence number of the given events. # A map from chain ID to max sequence number of the given events.
event_chains = {} # type: Dict[int, int] event_chains: Dict[int, int] = {}
sql = """ sql = """
SELECT event_id, chain_id, sequence_number SELECT event_id, chain_id, sequence_number
@ -182,7 +182,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
""" """
# A map from chain ID to max sequence number *reachable* from any event ID. # A map from chain ID to max sequence number *reachable* from any event ID.
chains = {} # type: Dict[int, int] chains: Dict[int, int] = {}
# Add all linked chains reachable from initial set of chains. # Add all linked chains reachable from initial set of chains.
for batch in batch_iter(event_chains, 1000): for batch in batch_iter(event_chains, 1000):
@ -353,14 +353,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
initial_events = set(state_sets[0]).union(*state_sets[1:]) initial_events = set(state_sets[0]).union(*state_sets[1:])
# Map from event_id -> (chain ID, seq no) # Map from event_id -> (chain ID, seq no)
chain_info = {} # type: Dict[str, Tuple[int, int]] chain_info: Dict[str, Tuple[int, int]] = {}
# Map from chain ID -> seq no -> event Id # Map from chain ID -> seq no -> event Id
chain_to_event = {} # type: Dict[int, Dict[int, str]] chain_to_event: Dict[int, Dict[int, str]] = {}
# All the chains that we've found that are reachable from the state # All the chains that we've found that are reachable from the state
# sets. # sets.
seen_chains = set() # type: Set[int] seen_chains: Set[int] = set()
sql = """ sql = """
SELECT event_id, chain_id, sequence_number SELECT event_id, chain_id, sequence_number
@ -392,9 +392,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Corresponds to `state_sets`, except as a map from chain ID to max # Corresponds to `state_sets`, except as a map from chain ID to max
# sequence number reachable from the state set. # sequence number reachable from the state set.
set_to_chain = [] # type: List[Dict[int, int]] set_to_chain: List[Dict[int, int]] = []
for state_set in state_sets: for state_set in state_sets:
chains = {} # type: Dict[int, int] chains: Dict[int, int] = {}
set_to_chain.append(chains) set_to_chain.append(chains)
for event_id in state_set: for event_id in state_set:
@ -446,7 +446,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Mapping from chain ID to the range of sequence numbers that should be # Mapping from chain ID to the range of sequence numbers that should be
# pulled from the database. # pulled from the database.
chain_to_gap = {} # type: Dict[int, Tuple[int, int]] chain_to_gap: Dict[int, Tuple[int, int]] = {}
for chain_id in seen_chains: for chain_id in seen_chains:
min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain) min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
@ -555,7 +555,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
} }
# The sorted list of events whose auth chains we should walk. # The sorted list of events whose auth chains we should walk.
search = [] # type: List[Tuple[int, str]] search: List[Tuple[int, str]] = []
# We need to get the depth of the initial events for sorting purposes. # We need to get the depth of the initial events for sorting purposes.
sql = """ sql = """
@ -578,7 +578,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
search.sort() search.sort()
# Map from event to its auth events # Map from event to its auth events
event_to_auth_events = {} # type: Dict[str, Set[str]] event_to_auth_events: Dict[str, Set[str]] = {}
base_sql = """ base_sql = """
SELECT a.event_id, auth_id, depth SELECT a.event_id, auth_id, depth

View file

@ -759,7 +759,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# object because we might not have the same amount of rows in each of them. To do # object because we might not have the same amount of rows in each of them. To do
# this, we use a dict indexed on the user ID and room ID to make it easier to # this, we use a dict indexed on the user ID and room ID to make it easier to
# populate. # populate.
summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary] summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
for row in txn: for row in txn:
summaries[(row[0], row[1])] = _EventPushSummary( summaries[(row[0], row[1])] = _EventPushSummary(
unread_count=row[2], unread_count=row[2],

View file

@ -109,10 +109,8 @@ class PersistEventsStore:
# Ideally we'd move these ID gens here, unfortunately some other ID # Ideally we'd move these ID gens here, unfortunately some other ID
# generators are chained off them so doing so is a bit of a PITA. # generators are chained off them so doing so is a bit of a PITA.
self._backfill_id_gen = ( self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
self.store._backfill_id_gen self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
) # type: MultiWriterIdGenerator
self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator
# This should only exist on instances that are configured to write # This should only exist on instances that are configured to write
assert ( assert (
@ -221,7 +219,7 @@ class PersistEventsStore:
Returns: Returns:
Filtered event ids Filtered event ids
""" """
results = [] # type: List[str] results: List[str] = []
def _get_events_which_are_prevs_txn(txn, batch): def _get_events_which_are_prevs_txn(txn, batch):
sql = """ sql = """
@ -508,7 +506,7 @@ class PersistEventsStore:
""" """
# Map from event ID to chain ID/sequence number. # Map from event ID to chain ID/sequence number.
chain_map = {} # type: Dict[str, Tuple[int, int]] chain_map: Dict[str, Tuple[int, int]] = {}
# Set of event IDs to calculate chain ID/seq numbers for. # Set of event IDs to calculate chain ID/seq numbers for.
events_to_calc_chain_id_for = set(event_to_room_id) events_to_calc_chain_id_for = set(event_to_room_id)
@ -817,8 +815,8 @@ class PersistEventsStore:
# new chain if the sequence number has already been allocated. # new chain if the sequence number has already been allocated.
# #
existing_chains = set() # type: Set[int] existing_chains: Set[int] = set()
tree = [] # type: List[Tuple[str, Optional[str]]] tree: List[Tuple[str, Optional[str]]] = []
# We need to do this in a topologically sorted order as we want to # We need to do this in a topologically sorted order as we want to
# generate chain IDs/sequence numbers of an event's auth events before # generate chain IDs/sequence numbers of an event's auth events before
@ -848,7 +846,7 @@ class PersistEventsStore:
) )
txn.execute(sql % (clause,), args) txn.execute(sql % (clause,), args)
chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int] chain_to_max_seq_no: Dict[Any, int] = {row[0]: row[1] for row in txn}
# Allocate the new events chain ID/sequence numbers. # Allocate the new events chain ID/sequence numbers.
# #
@ -858,8 +856,8 @@ class PersistEventsStore:
# number of new chain IDs in one call, replacing all temporary # number of new chain IDs in one call, replacing all temporary
# objects with real allocated chain IDs. # objects with real allocated chain IDs.
unallocated_chain_ids = set() # type: Set[object] unallocated_chain_ids: Set[object] = set()
new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]] new_chain_tuples: Dict[str, Tuple[Any, int]] = {}
for event_id, auth_event_id in tree: for event_id, auth_event_id in tree:
# If we reference an auth_event_id we fetch the allocated chain ID, # If we reference an auth_event_id we fetch the allocated chain ID,
# either from the existing `chain_map` or the newly generated # either from the existing `chain_map` or the newly generated
@ -870,7 +868,7 @@ class PersistEventsStore:
if not existing_chain_id: if not existing_chain_id:
existing_chain_id = chain_map[auth_event_id] existing_chain_id = chain_map[auth_event_id]
new_chain_tuple = None # type: Optional[Tuple[Any, int]] new_chain_tuple: Optional[Tuple[Any, int]] = None
if existing_chain_id: if existing_chain_id:
# We found a chain ID/sequence number candidate, check its # We found a chain ID/sequence number candidate, check its
# not already taken. # not already taken.
@ -897,9 +895,9 @@ class PersistEventsStore:
) )
# Map from potentially temporary chain ID to real chain ID # Map from potentially temporary chain ID to real chain ID
chain_id_to_allocated_map = dict( chain_id_to_allocated_map: Dict[Any, int] = dict(
zip(unallocated_chain_ids, newly_allocated_chain_ids) zip(unallocated_chain_ids, newly_allocated_chain_ids)
) # type: Dict[Any, int] )
chain_id_to_allocated_map.update((c, c) for c in existing_chains) chain_id_to_allocated_map.update((c, c) for c in existing_chains)
return { return {
@ -1175,9 +1173,9 @@ class PersistEventsStore:
Returns: Returns:
list[(EventBase, EventContext)]: filtered list list[(EventBase, EventContext)]: filtered list
""" """
new_events_and_contexts = ( new_events_and_contexts: OrderedDict[
OrderedDict() str, Tuple[EventBase, EventContext]
) # type: OrderedDict[str, Tuple[EventBase, EventContext]] ] = OrderedDict()
for event, context in events_and_contexts: for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id) prev_event_context = new_events_and_contexts.get(event.event_id)
if prev_event_context: if prev_event_context:
@ -1205,7 +1203,7 @@ class PersistEventsStore:
we are persisting we are persisting
backfilled (bool): True if the events were backfilled backfilled (bool): True if the events were backfilled
""" """
depth_updates = {} # type: Dict[str, int] depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts: for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids # Remove the any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id) txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@ -1885,7 +1883,7 @@ class PersistEventsStore:
), ),
) )
room_to_event_ids = {} # type: Dict[str, List[str]] room_to_event_ids: Dict[str, List[str]] = {}
for e, _ in events_and_contexts: for e, _ in events_and_contexts:
room_to_event_ids.setdefault(e.room_id, []).append(e.event_id) room_to_event_ids.setdefault(e.room_id, []).append(e.event_id)
@ -2012,7 +2010,7 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events. Forward extremities are handled when we first start persisting the events.
""" """
events_by_room = {} # type: Dict[str, List[EventBase]] events_by_room: Dict[str, List[EventBase]] = {}
for ev in events: for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev) events_by_room.setdefault(ev.room_id, []).append(ev)

View file

@ -960,9 +960,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
event_to_types = {row[0]: (row[1], row[2]) for row in rows} event_to_types = {row[0]: (row[1], row[2]) for row in rows}
# Calculate the new last position we've processed up to. # Calculate the new last position we've processed up to.
new_last_depth = rows[-1][3] if rows else last_depth # type: int new_last_depth: int = rows[-1][3] if rows else last_depth
new_last_stream = rows[-1][4] if rows else last_stream # type: int new_last_stream: int = rows[-1][4] if rows else last_stream
new_last_room_id = rows[-1][5] if rows else "" # type: str new_last_room_id: str = rows[-1][5] if rows else ""
# Map from room_id to last depth/stream_ordering processed for the room, # Map from room_id to last depth/stream_ordering processed for the room,
# excluding the last room (which we're likely still processing). We also # excluding the last room (which we're likely still processing). We also
@ -989,7 +989,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
retcols=("event_id", "auth_id"), retcols=("event_id", "auth_id"),
) )
event_to_auth_chain = {} # type: Dict[str, List[str]] event_to_auth_chain: Dict[str, List[str]] = {}
for row in auth_events: for row in auth_events:
event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"]) event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])

View file

@ -1365,10 +1365,10 @@ class EventsWorkerStore(SQLBaseStore):
# we need to make sure that, for every stream id in the results, we get *all* # we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id. # the rows with that stream id.
rows = await self.db_pool.runInteraction( rows: List[Tuple] = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas", "get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn, get_all_updated_current_state_deltas_txn,
) # type: List[Tuple] )
# if we've got fewer rows than the limit, we're good # if we've got fewer rows than the limit, we're good
if len(rows) < target_row_count: if len(rows) < target_row_count:
@ -1469,7 +1469,7 @@ class EventsWorkerStore(SQLBaseStore):
""" """
mapping = {} mapping = {}
txn_id_to_event = {} # type: Dict[Tuple[str, int, str], str] txn_id_to_event: Dict[Tuple[str, int, str], str] = {}
for event in events: for event in events:
token_id = getattr(event.internal_metadata, "token_id", None) token_id = getattr(event.internal_metadata, "token_id", None)

View file

@ -115,7 +115,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
logger.info("[purge] looking for events to delete") logger.info("[purge] looking for events to delete")
should_delete_expr = "state_key IS NULL" should_delete_expr = "state_key IS NULL"
should_delete_params = () # type: Tuple[Any, ...] should_delete_params: Tuple[Any, ...] = ()
if not delete_local_events: if not delete_local_events:
should_delete_expr += " AND event_id NOT LIKE ?" should_delete_expr += " AND event_id NOT LIKE ?"

View file

@ -79,9 +79,9 @@ class PushRulesWorkerStore(
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None: if hs.config.worker.worker_app is None:
self._push_rules_stream_id_gen = StreamIdGenerator( self._push_rules_stream_id_gen: Union[
db_conn, "push_rules_stream", "stream_id" StreamIdGenerator, SlavedIdTracker
) # type: Union[StreamIdGenerator, SlavedIdTracker] ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
else: else:
self._push_rules_stream_id_gen = SlavedIdTracker( self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id" db_conn, "push_rules_stream", "stream_id"

View file

@ -1744,7 +1744,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
items = keyvalues.items() items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items) where_clause = " AND ".join(k + " = ?" for k, _ in items)
values = [v for _, v in items] # type: List[Union[str, int]] values: List[Union[str, int]] = [v for _, v in items]
# Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
# is the `except_token_id` param that is tricky to get right, so for now we're just using the same where # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
# clause and values before we handle that. This seems to be only used in the "set password" handler. # clause and values before we handle that. This seems to be only used in the "set password" handler.

View file

@ -1085,9 +1085,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
# stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
# then filtering the results. # then filtering the results.
if from_token.topological is not None: if from_token.topological is not None:
from_bound = ( from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple()
from_token.as_historical_tuple()
) # type: Tuple[Optional[int], int]
elif direction == "b": elif direction == "b":
from_bound = ( from_bound = (
None, None,
@ -1099,7 +1097,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
from_token.stream, from_token.stream,
) )
to_bound = None # type: Optional[Tuple[Optional[int], int]] to_bound: Optional[Tuple[Optional[int], int]] = None
if to_token: if to_token:
if to_token.topological is not None: if to_token.topological is not None:
to_bound = to_token.as_historical_tuple() to_bound = to_token.as_historical_tuple()

View file

@ -42,7 +42,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
) )
tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]] tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in rows: for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {}) room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = db_to_json(row["content"]) room_tags[row["tag"]] = db_to_json(row["content"])

View file

@ -224,12 +224,12 @@ class UIAuthWorkerStore(SQLBaseStore):
self, txn: LoggingTransaction, session_id: str, key: str, value: Any self, txn: LoggingTransaction, session_id: str, key: str, value: Any
): ):
# Get the current value. # Get the current value.
result = self.db_pool.simple_select_one_txn( result: Dict[str, Any] = self.db_pool.simple_select_one_txn( # type: ignore
txn, txn,
table="ui_auth_sessions", table="ui_auth_sessions",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
retcols=("serverdict",), retcols=("serverdict",),
) # type: Dict[str, Any] # type: ignore )
# Update it and add it back to the database. # Update it and add it back to the database.
serverdict = db_to_json(result["serverdict"]) serverdict = db_to_json(result["serverdict"])

View file

@ -307,7 +307,7 @@ class EventsPersistenceStorage:
matched the transcation ID; the existing event is returned in such matched the transcation ID; the existing event is returned in such
a case. a case.
""" """
partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]] partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
for event, ctx in events_and_contexts: for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx)) partitioned.setdefault(event.room_id, []).append((event, ctx))
@ -384,7 +384,7 @@ class EventsPersistenceStorage:
A dictionary of event ID to event ID we didn't persist as we already A dictionary of event ID to event ID we didn't persist as we already
had another event persisted with the same TXN ID. had another event persisted with the same TXN ID.
""" """
replaced_events = {} # type: Dict[str, str] replaced_events: Dict[str, str] = {}
if not events_and_contexts: if not events_and_contexts:
return replaced_events return replaced_events
@ -440,16 +440,14 @@ class EventsPersistenceStorage:
# Set of remote users which were in rooms the server has left. We # Set of remote users which were in rooms the server has left. We
# should check if we still share any rooms and if not we mark their # should check if we still share any rooms and if not we mark their
# device lists as stale. # device lists as stale.
potentially_left_users = set() # type: Set[str] potentially_left_users: Set[str] = set()
if not backfilled: if not backfilled:
with Measure(self._clock, "_calculate_state_and_extrem"): with Measure(self._clock, "_calculate_state_and_extrem"):
# Work out the new "current state" for each room. # Work out the new "current state" for each room.
# We do this by working out what the new extremities are and then # We do this by working out what the new extremities are and then
# calculating the state from that. # calculating the state from that.
events_by_room = ( events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
{}
) # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, context in chunk: for event, context in chunk:
events_by_room.setdefault(event.room_id, []).append( events_by_room.setdefault(event.room_id, []).append(
(event, context) (event, context)
@ -622,9 +620,9 @@ class EventsPersistenceStorage:
) )
# Remove any events which are prev_events of any existing events. # Remove any events which are prev_events of any existing events.
existing_prevs = await self.persist_events_store._get_events_which_are_prevs( existing_prevs: Collection[
result str
) # type: Collection[str] ] = await self.persist_events_store._get_events_which_are_prevs(result)
result.difference_update(existing_prevs) result.difference_update(existing_prevs)
# Finally handle the case where the new events have soft-failed prev # Finally handle the case where the new events have soft-failed prev

View file

@ -256,7 +256,7 @@ def _setup_new_database(
for database in databases for database in databases
) )
directory_entries = [] # type: List[_DirectoryListing] directory_entries: List[_DirectoryListing] = []
for directory in directories: for directory in directories:
directory_entries.extend( directory_entries.extend(
_DirectoryListing(file_name, os.path.join(directory, file_name)) _DirectoryListing(file_name, os.path.join(directory, file_name))
@ -424,10 +424,10 @@ def _upgrade_existing_database(
directories.append(os.path.join(schema_path, database, "delta", str(v))) directories.append(os.path.join(schema_path, database, "delta", str(v)))
# Used to check if we have any duplicate file names # Used to check if we have any duplicate file names
file_name_counter = Counter() # type: CounterType[str] file_name_counter: CounterType[str] = Counter()
# Now find which directories have anything of interest. # Now find which directories have anything of interest.
directory_entries = [] # type: List[_DirectoryListing] directory_entries: List[_DirectoryListing] = []
for directory in directories: for directory in directories:
logger.debug("Looking for schema deltas in %s", directory) logger.debug("Looking for schema deltas in %s", directory)
try: try:

View file

@ -91,7 +91,7 @@ class StateFilter:
Returns: Returns:
The new state filter. The new state filter.
""" """
type_dict = {} # type: Dict[str, Optional[Set[str]]] type_dict: Dict[str, Optional[Set[str]]] = {}
for typ, s in types: for typ, s in types:
if typ in type_dict: if typ in type_dict:
if type_dict[typ] is None: if type_dict[typ] is None:
@ -194,7 +194,7 @@ class StateFilter:
""" """
where_clause = "" where_clause = ""
where_args = [] # type: List[str] where_args: List[str] = []
if self.is_full(): if self.is_full():
return where_clause, where_args return where_clause, where_args

View file

@ -112,7 +112,7 @@ class StreamIdGenerator:
# insertion ordering will ensure its in the correct ordering. # insertion ordering will ensure its in the correct ordering.
# #
# The key and values are the same, but we never look at the values. # The key and values are the same, but we never look at the values.
self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int] self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
def get_next(self): def get_next(self):
""" """
@ -236,15 +236,15 @@ class MultiWriterIdGenerator:
# Note: If we are a negative stream then we still store all the IDs as # Note: If we are a negative stream then we still store all the IDs as
# positive to make life easier for us, and simply negate the IDs when we # positive to make life easier for us, and simply negate the IDs when we
# return them. # return them.
self._current_positions = {} # type: Dict[str, int] self._current_positions: Dict[str, int] = {}
# Set of local IDs that we're still processing. The current position # Set of local IDs that we're still processing. The current position
# should be less than the minimum of this set (if not empty). # should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int] self._unfinished_ids: Set[int] = set()
# Set of local IDs that we've processed that are larger than the current # Set of local IDs that we've processed that are larger than the current
# position, due to there being smaller unpersisted IDs. # position, due to there being smaller unpersisted IDs.
self._finished_ids = set() # type: Set[int] self._finished_ids: Set[int] = set()
# We track the max position where we know everything before has been # We track the max position where we know everything before has been
# persisted. This is done by a) looking at the min across all instances # persisted. This is done by a) looking at the min across all instances
@ -265,7 +265,7 @@ class MultiWriterIdGenerator:
self._persisted_upto_position = ( self._persisted_upto_position = (
min(self._current_positions.values()) if self._current_positions else 1 min(self._current_positions.values()) if self._current_positions else 1
) )
self._known_persisted_positions = [] # type: List[int] self._known_persisted_positions: List[int] = []
self._sequence_gen = PostgresSequenceGenerator(sequence_name) self._sequence_gen = PostgresSequenceGenerator(sequence_name)
@ -465,7 +465,7 @@ class MultiWriterIdGenerator:
self._unfinished_ids.discard(next_id) self._unfinished_ids.discard(next_id)
self._finished_ids.add(next_id) self._finished_ids.add(next_id)
new_cur = None # type: Optional[int] new_cur: Optional[int] = None
if self._unfinished_ids: if self._unfinished_ids:
# If there are unfinished IDs then the new position will be the # If there are unfinished IDs then the new position will be the

View file

@ -208,10 +208,10 @@ class LocalSequenceGenerator(SequenceGenerator):
get_next_id_txn; should return the curreent maximum id get_next_id_txn; should return the curreent maximum id
""" """
# the callback. this is cleared after it is called, so that it can be GCed. # the callback. this is cleared after it is called, so that it can be GCed.
self._callback = get_first_callback # type: Optional[GetFirstCallbackType] self._callback: Optional[GetFirstCallbackType] = get_first_callback
# The current max value, or None if we haven't looked in the DB yet. # The current max value, or None if we haven't looked in the DB yet.
self._current_max_id = None # type: Optional[int] self._current_max_id: Optional[int] = None
self._lock = threading.Lock() self._lock = threading.Lock()
def get_next_id_txn(self, txn: Cursor) -> int: def get_next_id_txn(self, txn: Cursor) -> int:
@ -274,7 +274,7 @@ def build_sequence_generator(
`check_consistency` details. `check_consistency` details.
""" """
if isinstance(database_engine, PostgresEngine): if isinstance(database_engine, PostgresEngine):
seq = PostgresSequenceGenerator(sequence_name) # type: SequenceGenerator seq: SequenceGenerator = PostgresSequenceGenerator(sequence_name)
else: else:
seq = LocalSequenceGenerator(get_first_callback) seq = LocalSequenceGenerator(get_first_callback)

View file

@ -257,7 +257,7 @@ class Linearizer:
max_count: The maximum number of concurrent accesses max_count: The maximum number of concurrent accesses
""" """
if name is None: if name is None:
self.name = id(self) # type: Union[str, int] self.name: Union[str, int] = id(self)
else: else:
self.name = name self.name = name
@ -269,7 +269,7 @@ class Linearizer:
self.max_count = max_count self.max_count = max_count
# key_to_defer is a map from the key to a _LinearizerEntry. # key_to_defer is a map from the key to a _LinearizerEntry.
self.key_to_defer = {} # type: Dict[Hashable, _LinearizerEntry] self.key_to_defer: Dict[Hashable, _LinearizerEntry] = {}
def is_queued(self, key: Hashable) -> bool: def is_queued(self, key: Hashable) -> bool:
"""Checks whether there is a process queued up waiting""" """Checks whether there is a process queued up waiting"""
@ -409,10 +409,10 @@ class ReadWriteLock:
def __init__(self): def __init__(self):
# Latest readers queued # Latest readers queued
self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]] self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {}
# Latest writer queued # Latest writer queued
self.key_to_current_writer = {} # type: Dict[str, defer.Deferred] self.key_to_current_writer: Dict[str, defer.Deferred] = {}
async def read(self, key: str) -> ContextManager: async def read(self, key: str) -> ContextManager:
new_defer = defer.Deferred() new_defer = defer.Deferred()

View file

@ -93,11 +93,11 @@ class BatchingQueue(Generic[V, R]):
self._clock = clock self._clock = clock
# The set of keys currently being processed. # The set of keys currently being processed.
self._processing_keys = set() # type: Set[Hashable] self._processing_keys: Set[Hashable] = set()
# The currently pending batch of values by key, with a Deferred to call # The currently pending batch of values by key, with a Deferred to call
# with the result of the corresponding `_process_batch_callback` call. # with the result of the corresponding `_process_batch_callback` call.
self._next_values = {} # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]] self._next_values: Dict[Hashable, List[Tuple[V, defer.Deferred]]] = {}
# The function to call with batches of values. # The function to call with batches of values.
self._process_batch_callback = process_batch_callback self._process_batch_callback = process_batch_callback
@ -108,9 +108,7 @@ class BatchingQueue(Generic[V, R]):
number_of_keys.labels(self._name).set_function(lambda: len(self._next_values)) number_of_keys.labels(self._name).set_function(lambda: len(self._next_values))
self._number_in_flight_metric = number_in_flight.labels( self._number_in_flight_metric: Gauge = number_in_flight.labels(self._name)
self._name
) # type: Gauge
async def add_to_queue(self, value: V, key: Hashable = ()) -> R: async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
"""Adds the value to the queue with the given key, returning the result """Adds the value to the queue with the given key, returning the result

View file

@ -29,8 +29,8 @@ logger = logging.getLogger(__name__)
TRACK_MEMORY_USAGE = False TRACK_MEMORY_USAGE = False
caches_by_name = {} # type: Dict[str, Sized] caches_by_name: Dict[str, Sized] = {}
collectors_by_name = {} # type: Dict[str, CacheMetric] collectors_by_name: Dict[str, "CacheMetric"] = {}
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])

View file

@ -63,9 +63,9 @@ class CachedCall(Generic[TV]):
f: The underlying function. Only one call to this function will be alive f: The underlying function. Only one call to this function will be alive
at once (per instance of CachedCall) at once (per instance of CachedCall)
""" """
self._callable = f # type: Optional[Callable[[], Awaitable[TV]]] self._callable: Optional[Callable[[], Awaitable[TV]]] = f
self._deferred = None # type: Optional[Deferred] self._deferred: Optional[Deferred] = None
self._result = None # type: Union[None, Failure, TV] self._result: Union[None, Failure, TV] = None
async def get(self) -> TV: async def get(self) -> TV:
"""Kick off the call if necessary, and return the result""" """Kick off the call if necessary, and return the result"""

View file

@ -80,25 +80,25 @@ class DeferredCache(Generic[KT, VT]):
cache_type = TreeCache if tree else dict cache_type = TreeCache if tree else dict
# _pending_deferred_cache maps from the key value to a `CacheEntry` object. # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache = ( self._pending_deferred_cache: Union[
cache_type() TreeCache, "MutableMapping[KT, CacheEntry]"
) # type: Union[TreeCache, MutableMapping[KT, CacheEntry]] ] = cache_type()
def metrics_cb(): def metrics_cb():
cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
# cache is used for completed results and maps to the result itself, rather than # cache is used for completed results and maps to the result itself, rather than
# a Deferred. # a Deferred.
self.cache = LruCache( self.cache: LruCache[KT, VT] = LruCache(
max_size=max_entries, max_size=max_entries,
cache_name=name, cache_name=name,
cache_type=cache_type, cache_type=cache_type,
size_callback=(lambda d: len(d) or 1) if iterable else None, size_callback=(lambda d: len(d) or 1) if iterable else None,
metrics_collection_callback=metrics_cb, metrics_collection_callback=metrics_cb,
apply_cache_factor_from_config=apply_cache_factor_from_config, apply_cache_factor_from_config=apply_cache_factor_from_config,
) # type: LruCache[KT, VT] )
self.thread = None # type: Optional[threading.Thread] self.thread: Optional[threading.Thread] = None
@property @property
def max_entries(self): def max_entries(self):

View file

@ -46,17 +46,17 @@ F = TypeVar("F", bound=Callable[..., Any])
class _CachedFunction(Generic[F]): class _CachedFunction(Generic[F]):
invalidate = None # type: Any invalidate: Any = None
invalidate_all = None # type: Any invalidate_all: Any = None
prefill = None # type: Any prefill: Any = None
cache = None # type: Any cache: Any = None
num_args = None # type: Any num_args: Any = None
__name__ = None # type: str __name__: str
# Note: This function signature is actually fiddled with by the synapse mypy # Note: This function signature is actually fiddled with by the synapse mypy
# plugin to a) make it a bound method, and b) remove any `cache_context` arg. # plugin to a) make it a bound method, and b) remove any `cache_context` arg.
__call__ = None # type: F __call__: F
class _CacheDescriptorBase: class _CacheDescriptorBase:
@ -115,8 +115,8 @@ class _CacheDescriptorBase:
class _LruCachedFunction(Generic[F]): class _LruCachedFunction(Generic[F]):
cache = None # type: LruCache[CacheKey, Any] cache: LruCache[CacheKey, Any]
__call__ = None # type: F __call__: F
def lru_cache( def lru_cache(
@ -180,10 +180,10 @@ class LruCacheDescriptor(_CacheDescriptorBase):
self.max_entries = max_entries self.max_entries = max_entries
def __get__(self, obj, owner): def __get__(self, obj, owner):
cache = LruCache( cache: LruCache[CacheKey, Any] = LruCache(
cache_name=self.orig.__name__, cache_name=self.orig.__name__,
max_size=self.max_entries, max_size=self.max_entries,
) # type: LruCache[CacheKey, Any] )
get_cache_key = self.cache_key_builder get_cache_key = self.cache_key_builder
sentinel = LruCacheDescriptor._Sentinel.sentinel sentinel = LruCacheDescriptor._Sentinel.sentinel
@ -271,12 +271,12 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
self.iterable = iterable self.iterable = iterable
def __get__(self, obj, owner): def __get__(self, obj, owner):
cache = DeferredCache( cache: DeferredCache[CacheKey, Any] = DeferredCache(
name=self.orig.__name__, name=self.orig.__name__,
max_entries=self.max_entries, max_entries=self.max_entries,
tree=self.tree, tree=self.tree,
iterable=self.iterable, iterable=self.iterable,
) # type: DeferredCache[CacheKey, Any] )
get_cache_key = self.cache_key_builder get_cache_key = self.cache_key_builder
@ -359,7 +359,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
def __get__(self, obj, objtype=None): def __get__(self, obj, objtype=None):
cached_method = getattr(obj, self.cached_method_name) cached_method = getattr(obj, self.cached_method_name)
cache = cached_method.cache # type: DeferredCache[CacheKey, Any] cache: DeferredCache[CacheKey, Any] = cached_method.cache
num_args = cached_method.num_args num_args = cached_method.num_args
@functools.wraps(self.orig) @functools.wraps(self.orig)
@ -472,15 +472,15 @@ class _CacheContext:
Cache = Union[DeferredCache, LruCache] Cache = Union[DeferredCache, LruCache]
_cache_context_objects = ( _cache_context_objects: """WeakValueDictionary[
WeakValueDictionary() Tuple["_CacheContext.Cache", CacheKey], "_CacheContext"
) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext] ]""" = WeakValueDictionary()
def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None: def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
self._cache = cache self._cache = cache
self._cache_key = cache_key self._cache_key = cache_key
def invalidate(self): # type: () -> None def invalidate(self) -> None:
"""Invalidates the cache entry referred to by the context.""" """Invalidates the cache entry referred to by the context."""
self._cache.invalidate(self._cache_key) self._cache.invalidate(self._cache_key)

View file

@ -62,13 +62,13 @@ class DictionaryCache(Generic[KT, DKT]):
""" """
def __init__(self, name: str, max_entries: int = 1000): def __init__(self, name: str, max_entries: int = 1000):
self.cache = LruCache( self.cache: LruCache[KT, DictionaryEntry] = LruCache(
max_size=max_entries, cache_name=name, size_callback=len max_size=max_entries, cache_name=name, size_callback=len
) # type: LruCache[KT, DictionaryEntry] )
self.name = name self.name = name
self.sequence = 0 self.sequence = 0
self.thread = None # type: Optional[threading.Thread] self.thread: Optional[threading.Thread] = None
def check_thread(self) -> None: def check_thread(self) -> None:
expected_thread = self.thread expected_thread = self.thread

View file

@ -27,7 +27,7 @@ from synapse.util.caches import register_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SENTINEL = object() # type: Any SENTINEL: Any = object()
T = TypeVar("T") T = TypeVar("T")
@ -71,7 +71,7 @@ class ExpiringCache(Generic[KT, VT]):
self._expiry_ms = expiry_ms self._expiry_ms = expiry_ms
self._reset_expiry_on_get = reset_expiry_on_get self._reset_expiry_on_get = reset_expiry_on_get
self._cache = OrderedDict() # type: OrderedDict[KT, _CacheEntry] self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict()
self.iterable = iterable self.iterable = iterable

View file

@ -226,7 +226,7 @@ class _Node:
# footprint down. Storing `None` is free as its a singleton, while empty # footprint down. Storing `None` is free as its a singleton, while empty
# lists are 56 bytes (and empty sets are 216 bytes, if we did the naive # lists are 56 bytes (and empty sets are 216 bytes, if we did the naive
# thing and used sets). # thing and used sets).
self.callbacks = None # type: Optional[List[Callable[[], None]]] self.callbacks: Optional[List[Callable[[], None]]] = None
self.add_callbacks(callbacks) self.add_callbacks(callbacks)
@ -362,15 +362,15 @@ class LruCache(Generic[KT, VT]):
# register_cache might call our "set_cache_factor" callback; there's nothing to # register_cache might call our "set_cache_factor" callback; there's nothing to
# do yet when we get resized. # do yet when we get resized.
self._on_resize = None # type: Optional[Callable[[],None]] self._on_resize: Optional[Callable[[], None]] = None
if cache_name is not None: if cache_name is not None:
metrics = register_cache( metrics: Optional[CacheMetric] = register_cache(
"lru_cache", "lru_cache",
cache_name, cache_name,
self, self,
collect_callback=metrics_collection_callback, collect_callback=metrics_collection_callback,
) # type: Optional[CacheMetric] )
else: else:
metrics = None metrics = None

View file

@ -66,7 +66,7 @@ class ResponseCache(Generic[KV]):
# This is poorly-named: it includes both complete and incomplete results. # This is poorly-named: it includes both complete and incomplete results.
# We keep complete results rather than switching to absolute values because # We keep complete results rather than switching to absolute values because
# that makes it easier to cache Failure results. # that makes it easier to cache Failure results.
self.pending_result_cache = {} # type: Dict[KV, ObservableDeferred] self.pending_result_cache: Dict[KV, ObservableDeferred] = {}
self.clock = clock self.clock = clock
self.timeout_sec = timeout_ms / 1000.0 self.timeout_sec = timeout_ms / 1000.0

View file

@ -45,10 +45,10 @@ class StreamChangeCache:
): ):
self._original_max_size = max_size self._original_max_size = max_size
self._max_size = math.floor(max_size) self._max_size = math.floor(max_size)
self._entity_to_key = {} # type: Dict[EntityType, int] self._entity_to_key: Dict[EntityType, int] = {}
# map from stream id to the a set of entities which changed at that stream id. # map from stream id to the a set of entities which changed at that stream id.
self._cache = SortedDict() # type: SortedDict[int, Set[EntityType]] self._cache: SortedDict[int, Set[EntityType]] = SortedDict()
# the earliest stream_pos for which we can reliably answer # the earliest stream_pos for which we can reliably answer
# get_all_entities_changed. In other words, one less than the earliest # get_all_entities_changed. In other words, one less than the earliest
@ -155,7 +155,7 @@ class StreamChangeCache:
if stream_pos < self._earliest_known_stream_pos: if stream_pos < self._earliest_known_stream_pos:
return None return None
changed_entities = [] # type: List[EntityType] changed_entities: List[EntityType] = []
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)): for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
changed_entities.extend(self._cache[k]) changed_entities.extend(self._cache[k])

View file

@ -23,7 +23,7 @@ from synapse.util.caches import register_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SENTINEL = object() # type: Any SENTINEL: Any = object()
T = TypeVar("T") T = TypeVar("T")
KT = TypeVar("KT") KT = TypeVar("KT")
@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]):
def __init__(self, cache_name: str, timer: Callable[[], float] = time.time): def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
# map from key to _CacheEntry # map from key to _CacheEntry
self._data = {} # type: Dict[KT, _CacheEntry] self._data: Dict[KT, _CacheEntry] = {}
# the _CacheEntries, sorted by expiry time # the _CacheEntries, sorted by expiry time
self._expiry_list = SortedList() # type: SortedList[_CacheEntry] self._expiry_list: SortedList[_CacheEntry] = SortedList()
self._timer = timer self._timer = timer

View file

@ -68,7 +68,7 @@ def sorted_topologically(
# This is implemented by Kahn's algorithm. # This is implemented by Kahn's algorithm.
degree_map = {node: 0 for node in nodes} degree_map = {node: 0 for node in nodes}
reverse_graph = {} # type: Dict[T, Set[T]] reverse_graph: Dict[T, Set[T]] = {}
for node, edges in graph.items(): for node, edges in graph.items():
if node not in degree_map: if node not in degree_map:

View file

@ -39,7 +39,7 @@ def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
caveat in the macaroon, or if the caveat was not found in the macaroon. caveat in the macaroon, or if the caveat was not found in the macaroon.
""" """
prefix = key + " = " prefix = key + " = "
result = None # type: Optional[str] result: Optional[str] = None
for caveat in macaroon.caveats: for caveat in macaroon.caveats:
if not caveat.caveat_id.startswith(prefix): if not caveat.caveat_id.startswith(prefix):
continue continue

View file

@ -124,7 +124,7 @@ class Measure:
assert isinstance(curr_context, LoggingContext) assert isinstance(curr_context, LoggingContext)
parent_context = curr_context parent_context = curr_context
self._logging_context = LoggingContext(str(curr_context), parent_context) self._logging_context = LoggingContext(str(curr_context), parent_context)
self.start = None # type: Optional[int] self.start: Optional[int] = None
def __enter__(self) -> "Measure": def __enter__(self) -> "Measure":
if self.start is not None: if self.start is not None:

View file

@ -41,7 +41,7 @@ def do_patch():
@functools.wraps(f) @functools.wraps(f)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
start_context = current_context() start_context = current_context()
changes = [] # type: List[str] changes: List[str] = []
orig = orig_inline_callbacks(_check_yield_points(f, changes)) orig = orig_inline_callbacks(_check_yield_points(f, changes))
try: try:
@ -131,7 +131,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
gen = f(*args, **kwargs) gen = f(*args, **kwargs)
last_yield_line_no = gen.gi_frame.f_lineno last_yield_line_no = gen.gi_frame.f_lineno
result = None # type: Any result: Any = None
while True: while True:
expected_context = current_context() expected_context = current_context()