Add some missing type hints to cache datastore. (#12216)

This commit is contained in:
Patrick Cloke 2022-03-16 10:37:04 -04:00 committed by GitHub
parent 86965605a4
commit c486fa5fd9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 21 deletions

1
changelog.d/12216.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints for cache storage.

View file

@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
EventsStream, EventsStream,
EventsStreamCurrentStateRow, EventsStreamCurrentStateRow,
EventsStreamEventRow, EventsStreamEventRow,
EventsStreamRow,
) )
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import ( from synapse.storage.database import (
@ -31,6 +32,7 @@ from synapse.storage.database import (
LoggingTransaction, LoggingTransaction,
) )
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
if TYPE_CHECKING: if TYPE_CHECKING:
@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if last_id == current_id: if last_id == current_id:
return [], current_id, False return [], current_id, False
def get_all_updated_caches_txn(txn): def get_all_updated_caches_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We purposefully don't bound by the current token, as we want to # We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache # send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine. # invalidations are idempotent, so duplicates are fine.
@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"get_all_updated_caches", get_all_updated_caches_txn "get_all_updated_caches", get_all_updated_caches_txn
) )
def process_replication_rows(self, stream_name, instance_name, token, rows): def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == EventsStream.NAME: if stream_name == EventsStream.NAME:
for row in rows: for row in rows:
self._process_event_stream_row(token, row) self._process_event_stream_row(token, row)
@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows) super().process_replication_rows(stream_name, instance_name, token, rows)
def _process_event_stream_row(self, token, row): def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data data = row.data
if row.type == EventsStreamEventRow.TypeId: if row.type == EventsStreamEventRow.TypeId:
assert isinstance(data, EventsStreamEventRow)
self._invalidate_caches_for_event( self._invalidate_caches_for_event(
token, token,
data.event_id, data.event_id,
@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
backfilled=False, backfilled=False,
) )
elif row.type == EventsStreamCurrentStateRow.TypeId: elif row.type == EventsStreamCurrentStateRow.TypeId:
self._curr_state_delta_stream_cache.entity_has_changed( assert isinstance(data, EventsStreamCurrentStateRow)
row.data.room_id, token self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
)
if data.type == EventTypes.Member: if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate( self.get_rooms_for_user_with_stream_ordering.invalidate(
@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_caches_for_event( def _invalidate_caches_for_event(
self, self,
stream_ordering, stream_ordering: int,
event_id, event_id: str,
room_id, room_id: str,
etype, etype: str,
state_key, state_key: Optional[str],
redacts, redacts: Optional[str],
relates_to, relates_to: Optional[str],
backfilled, backfilled: bool,
): ) -> None:
self._invalidate_get_event_cache(event_id) self._invalidate_get_event_cache(event_id)
self.have_seen_event.invalidate((room_id, event_id)) self.have_seen_event.invalidate((room_id, event_id))
@ -207,7 +213,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_thread_summary.invalidate((relates_to,)) self.get_thread_summary.invalidate((relates_to,))
self.get_thread_participated.invalidate((relates_to,)) self.get_thread_participated.invalidate((relates_to,))
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[Any, ...]
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves """Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches. will know to invalidate their caches.
@ -227,7 +235,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
keys, keys,
) )
def _invalidate_cache_and_stream(self, txn, cache_func, keys): def _invalidate_cache_and_stream(
self,
txn: LoggingTransaction,
cache_func: _CachedFunction,
keys: Tuple[Any, ...],
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves """Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches. will know to invalidate their caches.
@ -238,7 +251,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate, keys) txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys) self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
def _invalidate_all_cache_and_stream(self, txn, cache_func): def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: _CachedFunction
) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves """Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches. will know to invalidate their caches.
""" """
@ -279,8 +294,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
) )
def _send_invalidation_to_replication( def _send_invalidation_to_replication(
self, txn, cache_name: str, keys: Optional[Iterable[Any]] self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
): ) -> None:
"""Notifies replication that given cache has been invalidated. """Notifies replication that given cache has been invalidated.
Note that this does *not* invalidate the cache locally. Note that this does *not* invalidate the cache locally.
@ -315,7 +330,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"instance_name": self._instance_name, "instance_name": self._instance_name,
"cache_func": cache_name, "cache_func": cache_name,
"keys": keys, "keys": keys,
"invalidation_ts": self.clock.time_msec(), "invalidation_ts": self._clock.time_msec(),
}, },
) )