Make all process_replication_rows methods async (#13304)

More prep work for asyncronous caching, also makes all process_replication_rows methods consistent (presence handler already is so).

Signed off by Nick @ Beeper (@Fizzadar)
This commit is contained in:
Nick Mills-Barrett 2022-07-17 23:19:43 +02:00 committed by GitHub
parent 96cf81e312
commit 5d4028f217
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 40 additions and 25 deletions

1
changelog.d/13304.misc Normal file
View file

@ -0,0 +1 @@
Make all replication row processing methods asynchronous. Contributed by Nick @ Beeper (@fizzadar).

View file

@ -158,7 +158,7 @@ class FollowerTypingHandler:
except Exception: except Exception:
logger.exception("Error pushing typing notif to remotes") logger.exception("Error pushing typing notif to remotes")
def process_replication_rows( async def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow] self, token: int, rows: List[TypingStream.TypingStreamRow]
) -> None: ) -> None:
"""Should be called whenever we receive updates for typing stream.""" """Should be called whenever we receive updates for typing stream."""
@ -444,7 +444,7 @@ class TypingWriterHandler(FollowerTypingHandler):
return rows, current_id, limited return rows, current_id, limited
def process_replication_rows( async def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow] self, token: int, rows: List[TypingStream.TypingStreamRow]
) -> None: ) -> None:
# The writing process should never get updates from replication. # The writing process should never get updates from replication.

View file

@ -49,7 +49,7 @@ class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
def get_device_stream_token(self) -> int: def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token() return self._device_list_id_gen.get_current_token()
def process_replication_rows( async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None: ) -> None:
if stream_name == DeviceListsStream.NAME: if stream_name == DeviceListsStream.NAME:
@ -59,7 +59,9 @@ class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
self._device_list_id_gen.advance(instance_name, token) self._device_list_id_gen.advance(instance_name, token)
for row in rows: for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token) self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows) return await super().process_replication_rows(
stream_name, instance_name, token, rows
)
def _invalidate_caches_for_devices( def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]

View file

@ -24,7 +24,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self) -> int: def get_max_push_rules_stream_id(self) -> int:
return self._push_rules_stream_id_gen.get_current_token() return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows( async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None: ) -> None:
if stream_name == PushRulesStream.NAME: if stream_name == PushRulesStream.NAME:
@ -33,4 +33,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
self.push_rules_stream_cache.entity_has_changed(row.user_id, token) self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows) return await super().process_replication_rows(
stream_name, instance_name, token, rows
)

View file

@ -40,9 +40,11 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def get_pushers_stream_token(self) -> int: def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token() return self._pushers_id_gen.get_current_token()
def process_replication_rows( async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None: ) -> None:
if stream_name == PushersStream.NAME: if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(instance_name, token) self._pushers_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows) return await super().process_replication_rows(
stream_name, instance_name, token, rows
)

View file

@ -144,13 +144,15 @@ class ReplicationDataHandler:
token: stream token for this batch of rows token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
""" """
self.store.process_replication_rows(stream_name, instance_name, token, rows) await self.store.process_replication_rows(
stream_name, instance_name, token, rows
)
if self.send_handler: if self.send_handler:
await self.send_handler.process_replication_rows(stream_name, token, rows) await self.send_handler.process_replication_rows(stream_name, token, rows)
if stream_name == TypingStream.NAME: if stream_name == TypingStream.NAME:
self._typing_handler.process_replication_rows(token, rows) await self._typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event( self.notifier.on_new_event(
StreamKeyType.TYPING, token, rooms=[row.room_id for row in rows] StreamKeyType.TYPING, token, rooms=[row.room_id for row in rows]
) )

View file

@ -47,7 +47,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self.database_engine = database.engine self.database_engine = database.engine
self.db_pool = database self.db_pool = database
def process_replication_rows( async def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,

View file

@ -414,7 +414,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
) )
) )
def process_replication_rows( async def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -437,7 +437,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
) )
self._account_data_stream_cache.entity_has_changed(row.user_id, token) self._account_data_stream_cache.entity_has_changed(row.user_id, token)
super().process_replication_rows(stream_name, instance_name, token, rows) await super().process_replication_rows(stream_name, instance_name, token, rows)
async def add_account_data_to_room( async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict self, user_id: str, room_id: str, account_data_type: str, content: JsonDict

View file

@ -119,7 +119,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"get_all_updated_caches", get_all_updated_caches_txn "get_all_updated_caches", get_all_updated_caches_txn
) )
def process_replication_rows( async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None: ) -> None:
if stream_name == EventsStream.NAME: if stream_name == EventsStream.NAME:
@ -154,7 +154,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
else: else:
self._attempt_to_invalidate_cache(row.cache_func, row.keys) self._attempt_to_invalidate_cache(row.cache_func, row.keys)
super().process_replication_rows(stream_name, instance_name, token, rows) await super().process_replication_rows(stream_name, instance_name, token, rows)
def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data data = row.data

View file

@ -128,7 +128,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
prefilled_cache=device_outbox_prefill, prefilled_cache=device_outbox_prefill,
) )
def process_replication_rows( async def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -148,7 +148,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
self._device_federation_outbox_stream_cache.entity_has_changed( self._device_federation_outbox_stream_cache.entity_has_changed(
row.entity, token row.entity, token
) )
return super().process_replication_rows(stream_name, instance_name, token, rows) return await super().process_replication_rows(
stream_name, instance_name, token, rows
)
def get_to_device_stream_token(self) -> int: def get_to_device_stream_token(self) -> int:
return self._device_inbox_id_gen.get_current_token() return self._device_inbox_id_gen.get_current_token()

View file

@ -280,7 +280,7 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id", id_column="chain_id",
) )
def process_replication_rows( async def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -292,7 +292,7 @@ class EventsWorkerStore(SQLBaseStore):
elif stream_name == BackfillStream.NAME: elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(instance_name, -token) self._backfill_id_gen.advance(instance_name, -token)
super().process_replication_rows(stream_name, instance_name, token, rows) await super().process_replication_rows(stream_name, instance_name, token, rows)
async def have_censored_event(self, event_id: str) -> bool: async def have_censored_event(self, event_id: str) -> bool:
"""Check if an event has been censored, i.e. if the content of the event has been erased """Check if an event has been censored, i.e. if the content of the event has been erased

View file

@ -431,7 +431,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
self._presence_on_startup = [] self._presence_on_startup = []
return active_on_startup return active_on_startup
def process_replication_rows( async def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -443,4 +443,6 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
for row in rows: for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token) self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,)) self._get_presence_for_user.invalidate((row.user_id,))
return super().process_replication_rows(stream_name, instance_name, token, rows) return await super().process_replication_rows(
stream_name, instance_name, token, rows
)

View file

@ -589,7 +589,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"get_unread_event_push_actions_by_room_for_user", (room_id,) "get_unread_event_push_actions_by_room_for_user", (room_id,)
) )
def process_replication_rows( async def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -604,7 +604,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
) )
self._receipts_stream_cache.entity_has_changed(row.room_id, token) self._receipts_stream_cache.entity_has_changed(row.room_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows) return await super().process_replication_rows(
stream_name, instance_name, token, rows
)
def _insert_linearized_receipt_txn( def _insert_linearized_receipt_txn(
self, self,

View file

@ -292,7 +292,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
# than the id that the client has. # than the id that the client has.
pass pass
def process_replication_rows( async def process_replication_rows(
self, self,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
@ -305,7 +305,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
self.get_tags_for_user.invalidate((row.user_id,)) self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token) self._account_data_stream_cache.entity_has_changed(row.user_id, token)
super().process_replication_rows(stream_name, instance_name, token, rows) await super().process_replication_rows(stream_name, instance_name, token, rows)
class TagsStore(TagsWorkerStore): class TagsStore(TagsWorkerStore):