forked from MirrorHub/synapse
Use inline type hints in http/federation/, storage/ and util/ (#10381)
This commit is contained in:
parent 3acf85c85f
commit bdfde6dca1
38 changed files with 149 additions and 161 deletions
changelog.d/10381.misc (new file, 1 addition)

@@ -0,0 +1 @@
+Convert internal type variable syntax to reflect wider ecosystem use.
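The whole commit applies one mechanical rewrite: comment-style annotations (`# type: ...`) are replaced by inline variable annotations (PEP 526), which type checkers treat identically. A minimal sketch of the two styles, using made-up variable names rather than anything from the diff:

from typing import Dict

# Old style: the type lives in a trailing comment that only type checkers read.
counts_old = {}  # type: Dict[str, int]

# New style: the same information as an inline variable annotation (PEP 526).
counts_new: Dict[str, int] = {}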
@@ -70,10 +70,8 @@ WELL_KNOWN_RETRY_ATTEMPTS = 3
 logger = logging.getLogger(__name__)
 
 
-_well_known_cache = TTLCache("well-known")  # type: TTLCache[bytes, Optional[bytes]]
-_had_valid_well_known_cache = TTLCache(
-    "had-valid-well-known"
-)  # type: TTLCache[bytes, bool]
+_well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
+_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")
 
 
 @attr.s(slots=True, frozen=True)

@@ -130,9 +128,10 @@ class WellKnownResolver:
         # requests for the same server in parallel?
         try:
             with Measure(self._clock, "get_well_known"):
-                result, cache_period = await self._fetch_well_known(
-                    server_name
-                )  # type: Optional[bytes], float
+                result: Optional[bytes]
+                cache_period: float
+
+                result, cache_period = await self._fetch_well_known(server_name)
 
         except _FetchWellKnownFailure as e:
             if prev_result and e.temporary:
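One case needs more than a one-line swap: a tuple-unpacking assignment could carry a single `# type:` comment covering both targets, but inline annotations are only allowed on single targets, so the hunk above pre-declares each variable before the unpacking. A sketch of the same pattern, with a hypothetical fetch function:

from typing import Optional, Tuple

def fetch_well_known() -> Tuple[Optional[bytes], float]:
    # Hypothetical stand-in; returns (response body, cache period in seconds).
    return None, 3600.0

# The old comment syntax could annotate both unpacking targets at once:
#     result, cache_period = fetch_well_known()  # type: Optional[bytes], float
# An inline annotation cannot, so each target is declared before the unpacking:
result: Optional[bytes]
cache_period: float

result, cache_period = fetch_well_known()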
@@ -92,14 +92,12 @@ class BackgroundUpdater:
         self.db_pool = database
 
         # if a background update is currently running, its name.
-        self._current_background_update = None  # type: Optional[str]
+        self._current_background_update: Optional[str] = None
 
-        self._background_update_performance = (
-            {}
-        )  # type: Dict[str, BackgroundUpdatePerformance]
-        self._background_update_handlers = (
-            {}
-        )  # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
+        self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
+        self._background_update_handlers: Dict[
+            str, Callable[[JsonDict, int], Awaitable[int]]
+        ] = {}
         self._all_done = False
 
     def start_doing_background_updates(self) -> None:
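Where the annotation is too long for one line, the new style wraps inside the annotation's brackets rather than wrapping the assigned value, as in the `_background_update_handlers` change above. A sketch with invented names (JsonDict here is a stand-in for Synapse's alias of that name):

from typing import Awaitable, Callable, Dict

JsonDict = Dict[str, object]  # stand-in for Synapse's JsonDict alias

# The annotation itself is split across lines; the value stays with the
# closing bracket.
update_handlers: Dict[
    str, Callable[[JsonDict, int], Awaitable[int]]
] = {}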
@@ -411,7 +409,7 @@ class BackgroundUpdater:
            c.execute(sql)
 
        if isinstance(self.db_pool.engine, engines.PostgresEngine):
-           runner = create_index_psql  # type: Optional[Callable[[Connection], None]]
+           runner: Optional[Callable[[Connection], None]] = create_index_psql
        elif psql_only:
            runner = None
        else:

@@ -670,8 +670,8 @@ class DatabasePool:
         Returns:
             The result of func
         """
-        after_callbacks = []  # type: List[_CallbackListEntry]
-        exception_callbacks = []  # type: List[_CallbackListEntry]
+        after_callbacks: List[_CallbackListEntry] = []
+        exception_callbacks: List[_CallbackListEntry] = []
 
         if not current_context():
             logger.warning("Starting db txn '%s' from sentinel context", desc)

@@ -1090,7 +1090,7 @@ class DatabasePool:
             return False
 
         # We didn't find any existing rows, so insert a new one
-        allvalues = {}  # type: Dict[str, Any]
+        allvalues: Dict[str, Any] = {}
         allvalues.update(keyvalues)
         allvalues.update(values)
         allvalues.update(insertion_values)

@@ -1121,7 +1121,7 @@ class DatabasePool:
             values: The nonunique columns and their new values
             insertion_values: additional key/values to use only when inserting
         """
-        allvalues = {}  # type: Dict[str, Any]
+        allvalues: Dict[str, Any] = {}
         allvalues.update(keyvalues)
         allvalues.update(insertion_values or {})
 

@@ -1257,7 +1257,7 @@ class DatabasePool:
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
         """
-        allnames = []  # type: List[str]
+        allnames: List[str] = []
         allnames.extend(key_names)
         allnames.extend(value_names)
 

@@ -1566,7 +1566,7 @@ class DatabasePool:
         """
         keyvalues = keyvalues or {}
 
-        results = []  # type: List[Dict[str, Any]]
+        results: List[Dict[str, Any]] = []
 
         if not iterable:
             return results

@@ -1978,7 +1978,7 @@ class DatabasePool:
             raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
 
         where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else ""
-        arg_list = []  # type: List[Any]
+        arg_list: List[Any] = []
         if filters:
             where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
             arg_list += list(filters.values())
@@ -48,9 +48,7 @@ def _make_exclusive_regex(
    ]
    if exclusive_user_regexes:
        exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
-       exclusive_user_pattern = re.compile(
-           exclusive_user_regex
-       )  # type: Optional[Pattern]
+       exclusive_user_pattern: Optional[Pattern] = re.compile(exclusive_user_regex)
    else:
        # We handle this case specially otherwise the constructed regex
        # will always match
@@ -247,7 +247,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
            txn.execute(sql, query_params)
 
-           result = {}  # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
+           result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
            for (user_id, device_id, display_name, key_json) in txn:
                if include_deleted_devices:
                    deleted_devices.remove((user_id, device_id))
@@ -62,9 +62,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         )
 
         # Cache of event ID to list of auth event IDs and their depths.
-        self._event_auth_cache = LruCache(
+        self._event_auth_cache: LruCache[str, List[Tuple[str, int]]] = LruCache(
             500000, "_event_auth_cache", size_callback=len
-        )  # type: LruCache[str, List[Tuple[str, int]]]
+        )
 
         self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
 

@@ -137,10 +137,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         initial_events = set(event_ids)
 
         # All the events that we've found that are reachable from the events.
-        seen_events = set()  # type: Set[str]
+        seen_events: Set[str] = set()
 
         # A map from chain ID to max sequence number of the given events.
-        event_chains = {}  # type: Dict[int, int]
+        event_chains: Dict[int, int] = {}
 
         sql = """
             SELECT event_id, chain_id, sequence_number

@@ -182,7 +182,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         """
 
         # A map from chain ID to max sequence number *reachable* from any event ID.
-        chains = {}  # type: Dict[int, int]
+        chains: Dict[int, int] = {}
 
         # Add all linked chains reachable from initial set of chains.
         for batch in batch_iter(event_chains, 1000):

@@ -353,14 +353,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         initial_events = set(state_sets[0]).union(*state_sets[1:])
 
         # Map from event_id -> (chain ID, seq no)
-        chain_info = {}  # type: Dict[str, Tuple[int, int]]
+        chain_info: Dict[str, Tuple[int, int]] = {}
 
         # Map from chain ID -> seq no -> event Id
-        chain_to_event = {}  # type: Dict[int, Dict[int, str]]
+        chain_to_event: Dict[int, Dict[int, str]] = {}
 
         # All the chains that we've found that are reachable from the state
         # sets.
-        seen_chains = set()  # type: Set[int]
+        seen_chains: Set[int] = set()
 
         sql = """
             SELECT event_id, chain_id, sequence_number

@@ -392,9 +392,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         # Corresponds to `state_sets`, except as a map from chain ID to max
         # sequence number reachable from the state set.
-        set_to_chain = []  # type: List[Dict[int, int]]
+        set_to_chain: List[Dict[int, int]] = []
         for state_set in state_sets:
-            chains = {}  # type: Dict[int, int]
+            chains: Dict[int, int] = {}
             set_to_chain.append(chains)
 
             for event_id in state_set:

@@ -446,7 +446,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         # Mapping from chain ID to the range of sequence numbers that should be
         # pulled from the database.
-        chain_to_gap = {}  # type: Dict[int, Tuple[int, int]]
+        chain_to_gap: Dict[int, Tuple[int, int]] = {}
 
         for chain_id in seen_chains:
             min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)

@@ -555,7 +555,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         }
 
         # The sorted list of events whose auth chains we should walk.
-        search = []  # type: List[Tuple[int, str]]
+        search: List[Tuple[int, str]] = []
 
         # We need to get the depth of the initial events for sorting purposes.
         sql = """

@@ -578,7 +578,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         search.sort()
 
         # Map from event to its auth events
-        event_to_auth_events = {}  # type: Dict[str, Set[str]]
+        event_to_auth_events: Dict[str, Set[str]] = {}
 
         base_sql = """
             SELECT a.event_id, auth_id, depth
@@ -759,7 +759,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
        # object because we might not have the same amount of rows in each of them. To do
        # this, we use a dict indexed on the user ID and room ID to make it easier to
        # populate.
-       summaries = {}  # type: Dict[Tuple[str, str], _EventPushSummary]
+       summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
        for row in txn:
            summaries[(row[0], row[1])] = _EventPushSummary(
                unread_count=row[2],
@@ -109,10 +109,8 @@ class PersistEventsStore:
 
         # Ideally we'd move these ID gens here, unfortunately some other ID
         # generators are chained off them so doing so is a bit of a PITA.
-        self._backfill_id_gen = (
-            self.store._backfill_id_gen
-        )  # type: MultiWriterIdGenerator
-        self._stream_id_gen = self.store._stream_id_gen  # type: MultiWriterIdGenerator
+        self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
+        self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
 
         # This should only exist on instances that are configured to write
         assert (

@@ -221,7 +219,7 @@ class PersistEventsStore:
         Returns:
             Filtered event ids
         """
-        results = []  # type: List[str]
+        results: List[str] = []
 
         def _get_events_which_are_prevs_txn(txn, batch):
             sql = """

@@ -508,7 +506,7 @@ class PersistEventsStore:
         """
 
         # Map from event ID to chain ID/sequence number.
-        chain_map = {}  # type: Dict[str, Tuple[int, int]]
+        chain_map: Dict[str, Tuple[int, int]] = {}
 
         # Set of event IDs to calculate chain ID/seq numbers for.
         events_to_calc_chain_id_for = set(event_to_room_id)

@@ -817,8 +815,8 @@ class PersistEventsStore:
         # new chain if the sequence number has already been allocated.
         #
 
-        existing_chains = set()  # type: Set[int]
-        tree = []  # type: List[Tuple[str, Optional[str]]]
+        existing_chains: Set[int] = set()
+        tree: List[Tuple[str, Optional[str]]] = []
 
         # We need to do this in a topologically sorted order as we want to
         # generate chain IDs/sequence numbers of an event's auth events before

@@ -848,7 +846,7 @@ class PersistEventsStore:
         )
         txn.execute(sql % (clause,), args)
 
-        chain_to_max_seq_no = {row[0]: row[1] for row in txn}  # type: Dict[Any, int]
+        chain_to_max_seq_no: Dict[Any, int] = {row[0]: row[1] for row in txn}
 
         # Allocate the new events chain ID/sequence numbers.
         #

@@ -858,8 +856,8 @@ class PersistEventsStore:
         # number of new chain IDs in one call, replacing all temporary
         # objects with real allocated chain IDs.
 
-        unallocated_chain_ids = set()  # type: Set[object]
-        new_chain_tuples = {}  # type: Dict[str, Tuple[Any, int]]
+        unallocated_chain_ids: Set[object] = set()
+        new_chain_tuples: Dict[str, Tuple[Any, int]] = {}
         for event_id, auth_event_id in tree:
             # If we reference an auth_event_id we fetch the allocated chain ID,
             # either from the existing `chain_map` or the newly generated

@@ -870,7 +868,7 @@ class PersistEventsStore:
            if not existing_chain_id:
                existing_chain_id = chain_map[auth_event_id]
 
-           new_chain_tuple = None  # type: Optional[Tuple[Any, int]]
+           new_chain_tuple: Optional[Tuple[Any, int]] = None
           if existing_chain_id:
                # We found a chain ID/sequence number candidate, check its
                # not already taken.

@@ -897,9 +895,9 @@ class PersistEventsStore:
         )
 
         # Map from potentially temporary chain ID to real chain ID
-        chain_id_to_allocated_map = dict(
+        chain_id_to_allocated_map: Dict[Any, int] = dict(
             zip(unallocated_chain_ids, newly_allocated_chain_ids)
-        )  # type: Dict[Any, int]
+        )
         chain_id_to_allocated_map.update((c, c) for c in existing_chains)
 
         return {

@@ -1175,9 +1173,9 @@ class PersistEventsStore:
         Returns:
             list[(EventBase, EventContext)]: filtered list
         """
-        new_events_and_contexts = (
-            OrderedDict()
-        )  # type: OrderedDict[str, Tuple[EventBase, EventContext]]
+        new_events_and_contexts: OrderedDict[
+            str, Tuple[EventBase, EventContext]
+        ] = OrderedDict()
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
             if prev_event_context:

@@ -1205,7 +1203,7 @@ class PersistEventsStore:
                 we are persisting
             backfilled (bool): True if the events were backfilled
         """
-        depth_updates = {}  # type: Dict[str, int]
+        depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
             txn.call_after(self.store._invalidate_get_event_cache, event.event_id)

@@ -1885,7 +1883,7 @@ class PersistEventsStore:
             ),
         )
 
-        room_to_event_ids = {}  # type: Dict[str, List[str]]
+        room_to_event_ids: Dict[str, List[str]] = {}
         for e, _ in events_and_contexts:
             room_to_event_ids.setdefault(e.room_id, []).append(e.event_id)
 

@@ -2012,7 +2010,7 @@ class PersistEventsStore:
 
         Forward extremities are handled when we first start persisting the events.
         """
-        events_by_room = {}  # type: Dict[str, List[EventBase]]
+        events_by_room: Dict[str, List[EventBase]] = {}
         for ev in events:
             events_by_room.setdefault(ev.room_id, []).append(ev)
 
@@ -960,9 +960,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
            event_to_types = {row[0]: (row[1], row[2]) for row in rows}
 
            # Calculate the new last position we've processed up to.
-           new_last_depth = rows[-1][3] if rows else last_depth  # type: int
-           new_last_stream = rows[-1][4] if rows else last_stream  # type: int
-           new_last_room_id = rows[-1][5] if rows else ""  # type: str
+           new_last_depth: int = rows[-1][3] if rows else last_depth
+           new_last_stream: int = rows[-1][4] if rows else last_stream
+           new_last_room_id: str = rows[-1][5] if rows else ""
 
            # Map from room_id to last depth/stream_ordering processed for the room,
            # excluding the last room (which we're likely still processing). We also

@@ -989,7 +989,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                retcols=("event_id", "auth_id"),
            )
 
-           event_to_auth_chain = {}  # type: Dict[str, List[str]]
+           event_to_auth_chain: Dict[str, List[str]] = {}
            for row in auth_events:
                event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
 
@@ -1365,10 +1365,10 @@ class EventsWorkerStore(SQLBaseStore):
         # we need to make sure that, for every stream id in the results, we get *all*
         # the rows with that stream id.
 
-        rows = await self.db_pool.runInteraction(
+        rows: List[Tuple] = await self.db_pool.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
-        )  # type: List[Tuple]
+        )
 
         # if we've got fewer rows than the limit, we're good
         if len(rows) < target_row_count:

@@ -1469,7 +1469,7 @@ class EventsWorkerStore(SQLBaseStore):
         """
 
         mapping = {}
-        txn_id_to_event = {}  # type: Dict[Tuple[str, int, str], str]
+        txn_id_to_event: Dict[Tuple[str, int, str], str] = {}
 
         for event in events:
             token_id = getattr(event.internal_metadata, "token_id", None)
@@ -115,7 +115,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
        logger.info("[purge] looking for events to delete")
 
        should_delete_expr = "state_key IS NULL"
-       should_delete_params = ()  # type: Tuple[Any, ...]
+       should_delete_params: Tuple[Any, ...] = ()
        if not delete_local_events:
            should_delete_expr += " AND event_id NOT LIKE ?"
 
@@ -79,9 +79,9 @@ class PushRulesWorkerStore(
        super().__init__(database, db_conn, hs)
 
        if hs.config.worker.worker_app is None:
-           self._push_rules_stream_id_gen = StreamIdGenerator(
-               db_conn, "push_rules_stream", "stream_id"
-           )  # type: Union[StreamIdGenerator, SlavedIdTracker]
+           self._push_rules_stream_id_gen: Union[
+               StreamIdGenerator, SlavedIdTracker
+           ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
        else:
            self._push_rules_stream_id_gen = SlavedIdTracker(
                db_conn, "push_rules_stream", "stream_id"
@@ -1744,7 +1744,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
        items = keyvalues.items()
        where_clause = " AND ".join(k + " = ?" for k, _ in items)
-       values = [v for _, v in items]  # type: List[Union[str, int]]
+       values: List[Union[str, int]] = [v for _, v in items]
        # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
        # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
        # clause and values before we handle that. This seems to be only used in the "set password" handler.
@@ -1085,9 +1085,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
        # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
        # then filtering the results.
        if from_token.topological is not None:
-           from_bound = (
-               from_token.as_historical_tuple()
-           )  # type: Tuple[Optional[int], int]
+           from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple()
        elif direction == "b":
            from_bound = (
                None,

@@ -1099,7 +1097,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
                from_token.stream,
            )
 
-       to_bound = None  # type: Optional[Tuple[Optional[int], int]]
+       to_bound: Optional[Tuple[Optional[int], int]] = None
        if to_token:
            if to_token.topological is not None:
                to_bound = to_token.as_historical_tuple()
@@ -42,7 +42,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
            "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
        )
 
-       tags_by_room = {}  # type: Dict[str, Dict[str, JsonDict]]
+       tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
        for row in rows:
            room_tags = tags_by_room.setdefault(row["room_id"], {})
            room_tags[row["tag"]] = db_to_json(row["content"])
@@ -224,12 +224,12 @@ class UIAuthWorkerStore(SQLBaseStore):
        self, txn: LoggingTransaction, session_id: str, key: str, value: Any
    ):
        # Get the current value.
-       result = self.db_pool.simple_select_one_txn(
+       result: Dict[str, Any] = self.db_pool.simple_select_one_txn(  # type: ignore
            txn,
            table="ui_auth_sessions",
            keyvalues={"session_id": session_id},
            retcols=("serverdict",),
-       )  # type: Dict[str, Any]  # type: ignore
+       )
 
        # Update it and add it back to the database.
        serverdict = db_to_json(result["serverdict"])
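Note in the hunk above that the `# type: ignore` which previously sat on the closing parenthesis moves onto the opening line of the statement, which is typically the line the type checker reports against for a multi-line call. A hedged sketch, with a hypothetical stand-in for the database helper:

from typing import Any, Dict

def simple_select_one(**kwargs: Any) -> Any:
    # Hypothetical stand-in; the real helper queries the database.
    return {"serverdict": "{}"}

# The suppression comment now sits on the statement's first line:
result: Dict[str, Any] = simple_select_one(  # type: ignore
    table="ui_auth_sessions",
    keyvalues={"session_id": "abc"},
)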
@@ -307,7 +307,7 @@ class EventsPersistenceStorage:
            matched the transcation ID; the existing event is returned in such
            a case.
        """
-       partitioned = {}  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
+       partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
        for event, ctx in events_and_contexts:
            partitioned.setdefault(event.room_id, []).append((event, ctx))
 

@@ -384,7 +384,7 @@ class EventsPersistenceStorage:
            A dictionary of event ID to event ID we didn't persist as we already
            had another event persisted with the same TXN ID.
        """
-       replaced_events = {}  # type: Dict[str, str]
+       replaced_events: Dict[str, str] = {}
        if not events_and_contexts:
            return replaced_events
 

@@ -440,16 +440,14 @@ class EventsPersistenceStorage:
        # Set of remote users which were in rooms the server has left. We
        # should check if we still share any rooms and if not we mark their
        # device lists as stale.
-       potentially_left_users = set()  # type: Set[str]
+       potentially_left_users: Set[str] = set()
 
        if not backfilled:
            with Measure(self._clock, "_calculate_state_and_extrem"):
                # Work out the new "current state" for each room.
                # We do this by working out what the new extremities are and then
                # calculating the state from that.
-               events_by_room = (
-                   {}
-               )  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
+               events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
                for event, context in chunk:
                    events_by_room.setdefault(event.room_id, []).append(
                        (event, context)

@@ -622,9 +620,9 @@ class EventsPersistenceStorage:
            )
 
            # Remove any events which are prev_events of any existing events.
-           existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
-               result
-           )  # type: Collection[str]
+           existing_prevs: Collection[
+               str
+           ] = await self.persist_events_store._get_events_which_are_prevs(result)
            result.difference_update(existing_prevs)
 
            # Finally handle the case where the new events have soft-failed prev
@@ -256,7 +256,7 @@ def _setup_new_database(
        for database in databases
    )
 
-   directory_entries = []  # type: List[_DirectoryListing]
+   directory_entries: List[_DirectoryListing] = []
    for directory in directories:
        directory_entries.extend(
            _DirectoryListing(file_name, os.path.join(directory, file_name))

@@ -424,10 +424,10 @@ def _upgrade_existing_database(
            directories.append(os.path.join(schema_path, database, "delta", str(v)))
 
        # Used to check if we have any duplicate file names
-       file_name_counter = Counter()  # type: CounterType[str]
+       file_name_counter: CounterType[str] = Counter()
 
        # Now find which directories have anything of interest.
-       directory_entries = []  # type: List[_DirectoryListing]
+       directory_entries: List[_DirectoryListing] = []
        for directory in directories:
            logger.debug("Looking for schema deltas in %s", directory)
            try:
@@ -91,7 +91,7 @@ class StateFilter:
        Returns:
            The new state filter.
        """
-       type_dict = {}  # type: Dict[str, Optional[Set[str]]]
+       type_dict: Dict[str, Optional[Set[str]]] = {}
        for typ, s in types:
            if typ in type_dict:
                if type_dict[typ] is None:

@@ -194,7 +194,7 @@ class StateFilter:
        """
 
        where_clause = ""
-       where_args = []  # type: List[str]
+       where_args: List[str] = []
 
        if self.is_full():
            return where_clause, where_args
@@ -112,7 +112,7 @@ class StreamIdGenerator:
        # insertion ordering will ensure its in the correct ordering.
        #
        # The key and values are the same, but we never look at the values.
-       self._unfinished_ids = OrderedDict()  # type: OrderedDict[int, int]
+       self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
 
    def get_next(self):
        """

@@ -236,15 +236,15 @@ class MultiWriterIdGenerator:
        # Note: If we are a negative stream then we still store all the IDs as
        # positive to make life easier for us, and simply negate the IDs when we
        # return them.
-       self._current_positions = {}  # type: Dict[str, int]
+       self._current_positions: Dict[str, int] = {}
 
        # Set of local IDs that we're still processing. The current position
        # should be less than the minimum of this set (if not empty).
-       self._unfinished_ids = set()  # type: Set[int]
+       self._unfinished_ids: Set[int] = set()
 
        # Set of local IDs that we've processed that are larger than the current
        # position, due to there being smaller unpersisted IDs.
-       self._finished_ids = set()  # type: Set[int]
+       self._finished_ids: Set[int] = set()
 
        # We track the max position where we know everything before has been
        # persisted. This is done by a) looking at the min across all instances

@@ -265,7 +265,7 @@ class MultiWriterIdGenerator:
        self._persisted_upto_position = (
            min(self._current_positions.values()) if self._current_positions else 1
        )
-       self._known_persisted_positions = []  # type: List[int]
+       self._known_persisted_positions: List[int] = []
 
        self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 

@@ -465,7 +465,7 @@ class MultiWriterIdGenerator:
            self._unfinished_ids.discard(next_id)
            self._finished_ids.add(next_id)
 
-           new_cur = None  # type: Optional[int]
+           new_cur: Optional[int] = None
 
            if self._unfinished_ids:
                # If there are unfinished IDs then the new position will be the
@@ -208,10 +208,10 @@ class LocalSequenceGenerator(SequenceGenerator):
                get_next_id_txn; should return the curreent maximum id
        """
        # the callback. this is cleared after it is called, so that it can be GCed.
-       self._callback = get_first_callback  # type: Optional[GetFirstCallbackType]
+       self._callback: Optional[GetFirstCallbackType] = get_first_callback
 
        # The current max value, or None if we haven't looked in the DB yet.
-       self._current_max_id = None  # type: Optional[int]
+       self._current_max_id: Optional[int] = None
        self._lock = threading.Lock()
 
    def get_next_id_txn(self, txn: Cursor) -> int:

@@ -274,7 +274,7 @@ def build_sequence_generator(
        `check_consistency` details.
    """
    if isinstance(database_engine, PostgresEngine):
-       seq = PostgresSequenceGenerator(sequence_name)  # type: SequenceGenerator
+       seq: SequenceGenerator = PostgresSequenceGenerator(sequence_name)
    else:
        seq = LocalSequenceGenerator(get_first_callback)
 
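The `seq` change above shows why several of these local variables were annotated in the first place: when two branches assign different concrete classes, the first assignment is annotated with the common base type so the type checker does not infer the narrower one. A sketch with stub classes standing in for the real generators:

# Stub classes standing in for the real sequence generators referenced above.
class SequenceGenerator:
    pass

class PostgresSequenceGenerator(SequenceGenerator):
    pass

class LocalSequenceGenerator(SequenceGenerator):
    pass

def build(use_postgres: bool) -> SequenceGenerator:
    if use_postgres:
        # Annotating with the base type keeps mypy from inferring the narrower
        # Postgres type and then rejecting the assignment in the else branch.
        seq: SequenceGenerator = PostgresSequenceGenerator()
    else:
        seq = LocalSequenceGenerator()
    return seq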
@@ -257,7 +257,7 @@ class Linearizer:
            max_count: The maximum number of concurrent accesses
        """
        if name is None:
-           self.name = id(self)  # type: Union[str, int]
+           self.name: Union[str, int] = id(self)
        else:
            self.name = name
 

@@ -269,7 +269,7 @@ class Linearizer:
        self.max_count = max_count
 
        # key_to_defer is a map from the key to a _LinearizerEntry.
-       self.key_to_defer = {}  # type: Dict[Hashable, _LinearizerEntry]
+       self.key_to_defer: Dict[Hashable, _LinearizerEntry] = {}
 
    def is_queued(self, key: Hashable) -> bool:
        """Checks whether there is a process queued up waiting"""

@@ -409,10 +409,10 @@ class ReadWriteLock:
 
    def __init__(self):
        # Latest readers queued
-       self.key_to_current_readers = {}  # type: Dict[str, Set[defer.Deferred]]
+       self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {}
 
        # Latest writer queued
-       self.key_to_current_writer = {}  # type: Dict[str, defer.Deferred]
+       self.key_to_current_writer: Dict[str, defer.Deferred] = {}
 
    async def read(self, key: str) -> ContextManager:
        new_defer = defer.Deferred()
@@ -93,11 +93,11 @@ class BatchingQueue(Generic[V, R]):
        self._clock = clock
 
        # The set of keys currently being processed.
-       self._processing_keys = set()  # type: Set[Hashable]
+       self._processing_keys: Set[Hashable] = set()
 
        # The currently pending batch of values by key, with a Deferred to call
        # with the result of the corresponding `_process_batch_callback` call.
-       self._next_values = {}  # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]]
+       self._next_values: Dict[Hashable, List[Tuple[V, defer.Deferred]]] = {}
 
        # The function to call with batches of values.
        self._process_batch_callback = process_batch_callback

@@ -108,9 +108,7 @@ class BatchingQueue(Generic[V, R]):
 
        number_of_keys.labels(self._name).set_function(lambda: len(self._next_values))
 
-       self._number_in_flight_metric = number_in_flight.labels(
-           self._name
-       )  # type: Gauge
+       self._number_in_flight_metric: Gauge = number_in_flight.labels(self._name)
 
    async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
        """Adds the value to the queue with the given key, returning the result
@@ -29,8 +29,8 @@ logger = logging.getLogger(__name__)
 TRACK_MEMORY_USAGE = False
 
 
-caches_by_name = {}  # type: Dict[str, Sized]
-collectors_by_name = {}  # type: Dict[str, CacheMetric]
+caches_by_name: Dict[str, Sized] = {}
+collectors_by_name: Dict[str, "CacheMetric"] = {}
 
 cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
 cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
@@ -63,9 +63,9 @@ class CachedCall(Generic[TV]):
            f: The underlying function. Only one call to this function will be alive
                at once (per instance of CachedCall)
        """
-       self._callable = f  # type: Optional[Callable[[], Awaitable[TV]]]
-       self._deferred = None  # type: Optional[Deferred]
-       self._result = None  # type: Union[None, Failure, TV]
+       self._callable: Optional[Callable[[], Awaitable[TV]]] = f
+       self._deferred: Optional[Deferred] = None
+       self._result: Union[None, Failure, TV] = None
 
    async def get(self) -> TV:
        """Kick off the call if necessary, and return the result"""
@@ -80,25 +80,25 @@ class DeferredCache(Generic[KT, VT]):
        cache_type = TreeCache if tree else dict
 
        # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
-       self._pending_deferred_cache = (
-           cache_type()
-       )  # type: Union[TreeCache, MutableMapping[KT, CacheEntry]]
+       self._pending_deferred_cache: Union[
+           TreeCache, "MutableMapping[KT, CacheEntry]"
+       ] = cache_type()
 
        def metrics_cb():
            cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
 
        # cache is used for completed results and maps to the result itself, rather than
        # a Deferred.
-       self.cache = LruCache(
+       self.cache: LruCache[KT, VT] = LruCache(
            max_size=max_entries,
            cache_name=name,
            cache_type=cache_type,
            size_callback=(lambda d: len(d) or 1) if iterable else None,
            metrics_collection_callback=metrics_cb,
            apply_cache_factor_from_config=apply_cache_factor_from_config,
-       )  # type: LruCache[KT, VT]
+       )
 
-       self.thread = None  # type: Optional[threading.Thread]
+       self.thread: Optional[threading.Thread] = None
 
    @property
    def max_entries(self):
@@ -46,17 +46,17 @@ F = TypeVar("F", bound=Callable[..., Any])
 
 
 class _CachedFunction(Generic[F]):
-    invalidate = None  # type: Any
-    invalidate_all = None  # type: Any
-    prefill = None  # type: Any
-    cache = None  # type: Any
-    num_args = None  # type: Any
+    invalidate: Any = None
+    invalidate_all: Any = None
+    prefill: Any = None
+    cache: Any = None
+    num_args: Any = None
 
-    __name__ = None  # type: str
+    __name__: str
 
     # Note: This function signature is actually fiddled with by the synapse mypy
     # plugin to a) make it a bound method, and b) remove any `cache_context` arg.
-    __call__ = None  # type: F
+    __call__: F
 
 
 class _CacheDescriptorBase:
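For `__name__` and `__call__` above the conversion drops the `= None` entirely: a bare annotation declares the attribute for the type checker without assigning a wrongly-typed default on the class at runtime. A sketch with an invented class name:

from typing import Any, Callable

class _CachedFunctionSketch:
    # The comment style had to assign a dummy value in order to have
    # something to attach the type to:
    #     __name__ = None  # type: str
    # A bare inline annotation declares the attribute without a default.
    cache: Any = None             # keeps its runtime default
    __name__: str                 # declared only; no class-level value
    __call__: Callable[..., Any]  # declared only; no class-level value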
@@ -115,8 +115,8 @@ class _CacheDescriptorBase:
 
 
 class _LruCachedFunction(Generic[F]):
-    cache = None  # type: LruCache[CacheKey, Any]
-    __call__ = None  # type: F
+    cache: LruCache[CacheKey, Any]
+    __call__: F
 
 
 def lru_cache(

@@ -180,10 +180,10 @@ class LruCacheDescriptor(_CacheDescriptorBase):
        self.max_entries = max_entries
 
    def __get__(self, obj, owner):
-       cache = LruCache(
+       cache: LruCache[CacheKey, Any] = LruCache(
            cache_name=self.orig.__name__,
            max_size=self.max_entries,
-       )  # type: LruCache[CacheKey, Any]
+       )
 
        get_cache_key = self.cache_key_builder
        sentinel = LruCacheDescriptor._Sentinel.sentinel

@@ -271,12 +271,12 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
        self.iterable = iterable
 
    def __get__(self, obj, owner):
-       cache = DeferredCache(
+       cache: DeferredCache[CacheKey, Any] = DeferredCache(
            name=self.orig.__name__,
            max_entries=self.max_entries,
            tree=self.tree,
            iterable=self.iterable,
-       )  # type: DeferredCache[CacheKey, Any]
+       )
 
        get_cache_key = self.cache_key_builder
 

@@ -359,7 +359,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
 
    def __get__(self, obj, objtype=None):
        cached_method = getattr(obj, self.cached_method_name)
-       cache = cached_method.cache  # type: DeferredCache[CacheKey, Any]
+       cache: DeferredCache[CacheKey, Any] = cached_method.cache
        num_args = cached_method.num_args
 
        @functools.wraps(self.orig)
@@ -472,15 +472,15 @@ class _CacheContext:
 
    Cache = Union[DeferredCache, LruCache]
 
-   _cache_context_objects = (
-       WeakValueDictionary()
-   )  # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
+   _cache_context_objects: """WeakValueDictionary[
+       Tuple["_CacheContext.Cache", CacheKey], "_CacheContext"
+   ]""" = WeakValueDictionary()
 
    def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
        self._cache = cache
        self._cache_key = cache_key
 
-   def invalidate(self):  # type: () -> None
+   def invalidate(self) -> None:
        """Invalidates the cache entry referred to by the context."""
        self._cache.invalidate(self._cache_key)
 
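The `_cache_context_objects` hunk above uses a quoted (string) annotation because the annotation refers to `_CacheContext` itself, which is still being defined; string annotations are never evaluated at runtime, so forward references are allowed. A simplified sketch with an invented class:

from weakref import WeakValueDictionary

class Registry:
    # The annotation is a string because it mentions Registry, which is not
    # fully defined at this point; quoted annotations are evaluated lazily.
    _instances: "WeakValueDictionary[str, Registry]" = WeakValueDictionary()

    def __init__(self, key: str) -> None:
        self.key = key
        Registry._instances[key] = self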
@@ -62,13 +62,13 @@ class DictionaryCache(Generic[KT, DKT]):
    """
 
    def __init__(self, name: str, max_entries: int = 1000):
-       self.cache = LruCache(
+       self.cache: LruCache[KT, DictionaryEntry] = LruCache(
            max_size=max_entries, cache_name=name, size_callback=len
-       )  # type: LruCache[KT, DictionaryEntry]
+       )
 
        self.name = name
        self.sequence = 0
-       self.thread = None  # type: Optional[threading.Thread]
+       self.thread: Optional[threading.Thread] = None
 
    def check_thread(self) -> None:
        expected_thread = self.thread
@@ -27,7 +27,7 @@ from synapse.util.caches import register_cache
 logger = logging.getLogger(__name__)
 
 
-SENTINEL = object()  # type: Any
+SENTINEL: Any = object()
 
 
 T = TypeVar("T")

@@ -71,7 +71,7 @@ class ExpiringCache(Generic[KT, VT]):
        self._expiry_ms = expiry_ms
        self._reset_expiry_on_get = reset_expiry_on_get
 
-       self._cache = OrderedDict()  # type: OrderedDict[KT, _CacheEntry]
+       self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict()
 
        self.iterable = iterable
 
@@ -226,7 +226,7 @@ class _Node:
        # footprint down. Storing `None` is free as its a singleton, while empty
        # lists are 56 bytes (and empty sets are 216 bytes, if we did the naive
        # thing and used sets).
-       self.callbacks = None  # type: Optional[List[Callable[[], None]]]
+       self.callbacks: Optional[List[Callable[[], None]]] = None
 
        self.add_callbacks(callbacks)
 

@@ -362,15 +362,15 @@ class LruCache(Generic[KT, VT]):
 
        # register_cache might call our "set_cache_factor" callback; there's nothing to
        # do yet when we get resized.
-       self._on_resize = None  # type: Optional[Callable[[],None]]
+       self._on_resize: Optional[Callable[[], None]] = None
 
        if cache_name is not None:
-           metrics = register_cache(
+           metrics: Optional[CacheMetric] = register_cache(
                "lru_cache",
                cache_name,
                self,
                collect_callback=metrics_collection_callback,
-           )  # type: Optional[CacheMetric]
+           )
        else:
            metrics = None
 
@@ -66,7 +66,7 @@ class ResponseCache(Generic[KV]):
        # This is poorly-named: it includes both complete and incomplete results.
        # We keep complete results rather than switching to absolute values because
        # that makes it easier to cache Failure results.
-       self.pending_result_cache = {}  # type: Dict[KV, ObservableDeferred]
+       self.pending_result_cache: Dict[KV, ObservableDeferred] = {}
 
        self.clock = clock
        self.timeout_sec = timeout_ms / 1000.0
@@ -45,10 +45,10 @@ class StreamChangeCache:
    ):
        self._original_max_size = max_size
        self._max_size = math.floor(max_size)
-       self._entity_to_key = {}  # type: Dict[EntityType, int]
+       self._entity_to_key: Dict[EntityType, int] = {}
 
        # map from stream id to the a set of entities which changed at that stream id.
-       self._cache = SortedDict()  # type: SortedDict[int, Set[EntityType]]
+       self._cache: SortedDict[int, Set[EntityType]] = SortedDict()
 
        # the earliest stream_pos for which we can reliably answer
        # get_all_entities_changed. In other words, one less than the earliest

@@ -155,7 +155,7 @@ class StreamChangeCache:
        if stream_pos < self._earliest_known_stream_pos:
            return None
 
-       changed_entities = []  # type: List[EntityType]
+       changed_entities: List[EntityType] = []
 
        for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
            changed_entities.extend(self._cache[k])
@@ -23,7 +23,7 @@ from synapse.util.caches import register_cache
 
 logger = logging.getLogger(__name__)
 
-SENTINEL = object()  # type: Any
+SENTINEL: Any = object()
 
 T = TypeVar("T")
 KT = TypeVar("KT")

@@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]):
 
    def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
        # map from key to _CacheEntry
-       self._data = {}  # type: Dict[KT, _CacheEntry]
+       self._data: Dict[KT, _CacheEntry] = {}
 
        # the _CacheEntries, sorted by expiry time
-       self._expiry_list = SortedList()  # type: SortedList[_CacheEntry]
+       self._expiry_list: SortedList[_CacheEntry] = SortedList()
 
        self._timer = timer
 
@@ -68,7 +68,7 @@ def sorted_topologically(
    # This is implemented by Kahn's algorithm.
 
    degree_map = {node: 0 for node in nodes}
-   reverse_graph = {}  # type: Dict[T, Set[T]]
+   reverse_graph: Dict[T, Set[T]] = {}
 
    for node, edges in graph.items():
        if node not in degree_map:
@@ -39,7 +39,7 @@ def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
        caveat in the macaroon, or if the caveat was not found in the macaroon.
    """
    prefix = key + " = "
-   result = None  # type: Optional[str]
+   result: Optional[str] = None
    for caveat in macaroon.caveats:
        if not caveat.caveat_id.startswith(prefix):
            continue
@@ -124,7 +124,7 @@ class Measure:
            assert isinstance(curr_context, LoggingContext)
            parent_context = curr_context
        self._logging_context = LoggingContext(str(curr_context), parent_context)
-       self.start = None  # type: Optional[int]
+       self.start: Optional[int] = None
 
    def __enter__(self) -> "Measure":
        if self.start is not None:
@@ -41,7 +41,7 @@ def do_patch():
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            start_context = current_context()
-           changes = []  # type: List[str]
+           changes: List[str] = []
            orig = orig_inline_callbacks(_check_yield_points(f, changes))
 
            try:

@@ -131,7 +131,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
        gen = f(*args, **kwargs)
 
        last_yield_line_no = gen.gi_frame.f_lineno
-       result = None  # type: Any
+       result: Any = None
        while True:
            expected_context = current_context()
 