Add missing type hints to event fetching. (#11121)

Updates the event rows returned from the database to be
attrs classes instead of dictionaries.
This commit is contained in:
Patrick Cloke 2021-10-19 10:29:03 -04:00 committed by GitHub
parent 5e0e683541
commit 0dd0c40329
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 82 additions and 61 deletions

1
changelog.d/11121.misc Normal file
View file

@ -0,0 +1 @@
Add type hints for event fetching.

View file

@ -55,8 +55,9 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id from synapse.types import JsonDict, get_domain_from_id
@ -86,6 +87,47 @@ class _EventCacheEntry:
redacted_event: Optional[EventBase] redacted_event: Optional[EventBase]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventRow:
"""
An event, as pulled from the database.
Properties:
event_id: The event ID of the event.
stream_ordering: stream ordering for this event
json: json-encoded event structure
internal_metadata: json-encoded internal metadata dict
format_version: The format of the event. Hopefully one of EventFormatVersions.
'None' means the event predates EventFormatVersions (so the event is format V1).
room_version_id: The version of the room which contains the event. Hopefully
one of RoomVersions.
Due to historical reasons, there may be a few events in the database which
do not have an associated room; in this case None will be returned here.
rejected_reason: if the event was rejected, the reason why.
redactions: a list of event-ids which (claim to) redact this event.
outlier: True if this event is an outlier.
"""
event_id: str
stream_ordering: int
json: str
internal_metadata: str
format_version: Optional[int]
room_version_id: Optional[str]
rejected_reason: Optional[str]
redactions: List[str]
outlier: bool
class EventRedactBehaviour(Names): class EventRedactBehaviour(Names):
""" """
What to do when retrieving a redacted event from the database. What to do when retrieving a redacted event from the database.
@ -686,7 +728,7 @@ class EventsWorkerStore(SQLBaseStore):
for e in state_to_include.values() for e in state_to_include.values()
] ]
def _do_fetch(self, conn): def _do_fetch(self, conn: Connection) -> None:
"""Takes a database connection and waits for requests for events from """Takes a database connection and waits for requests for events from
the _event_fetch_list queue. the _event_fetch_list queue.
""" """
@ -713,13 +755,15 @@ class EventsWorkerStore(SQLBaseStore):
self._fetch_event_list(conn, event_list) self._fetch_event_list(conn, event_list)
def _fetch_event_list(self, conn, event_list): def _fetch_event_list(
self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
) -> None:
"""Handle a load of requests from the _event_fetch_list queue """Handle a load of requests from the _event_fetch_list queue
Args: Args:
conn (twisted.enterprise.adbapi.Connection): database connection conn: database connection
event_list (list[Tuple[list[str], Deferred]]): event_list:
The fetch requests. Each entry consists of a list of event The fetch requests. Each entry consists of a list of event
ids to be fetched, and a deferred to be completed once the ids to be fetched, and a deferred to be completed once the
events have been fetched. events have been fetched.
@ -788,7 +832,7 @@ class EventsWorkerStore(SQLBaseStore):
row = row_map.get(event_id) row = row_map.get(event_id)
fetched_events[event_id] = row fetched_events[event_id] = row
if row: if row:
redaction_ids.update(row["redactions"]) redaction_ids.update(row.redactions)
events_to_fetch = redaction_ids.difference(fetched_events.keys()) events_to_fetch = redaction_ids.difference(fetched_events.keys())
if events_to_fetch: if events_to_fetch:
@ -799,32 +843,32 @@ class EventsWorkerStore(SQLBaseStore):
for event_id, row in fetched_events.items(): for event_id, row in fetched_events.items():
if not row: if not row:
continue continue
assert row["event_id"] == event_id assert row.event_id == event_id
rejected_reason = row["rejected_reason"] rejected_reason = row.rejected_reason
# If the event or metadata cannot be parsed, log the error and act # If the event or metadata cannot be parsed, log the error and act
# as if the event is unknown. # as if the event is unknown.
try: try:
d = db_to_json(row["json"]) d = db_to_json(row.json)
except ValueError: except ValueError:
logger.error("Unable to parse json from event: %s", event_id) logger.error("Unable to parse json from event: %s", event_id)
continue continue
try: try:
internal_metadata = db_to_json(row["internal_metadata"]) internal_metadata = db_to_json(row.internal_metadata)
except ValueError: except ValueError:
logger.error( logger.error(
"Unable to parse internal_metadata from event: %s", event_id "Unable to parse internal_metadata from event: %s", event_id
) )
continue continue
format_version = row["format_version"] format_version = row.format_version
if format_version is None: if format_version is None:
# This means that we stored the event before we had the concept # This means that we stored the event before we had the concept
# of an event format version, so it must be a V1 event. # of an event format version, so it must be a V1 event.
format_version = EventFormatVersions.V1 format_version = EventFormatVersions.V1
room_version_id = row["room_version_id"] room_version_id = row.room_version_id
if not room_version_id: if not room_version_id:
# this should only happen for out-of-band membership events which # this should only happen for out-of-band membership events which
@ -889,8 +933,8 @@ class EventsWorkerStore(SQLBaseStore):
internal_metadata_dict=internal_metadata, internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason, rejected_reason=rejected_reason,
) )
original_ev.internal_metadata.stream_ordering = row["stream_ordering"] original_ev.internal_metadata.stream_ordering = row.stream_ordering
original_ev.internal_metadata.outlier = row["outlier"] original_ev.internal_metadata.outlier = row.outlier
event_map[event_id] = original_ev event_map[event_id] = original_ev
@ -898,7 +942,7 @@ class EventsWorkerStore(SQLBaseStore):
# the cache entries. # the cache entries.
result_map = {} result_map = {}
for event_id, original_ev in event_map.items(): for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id]["redactions"] redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row( redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map original_ev, redactions, event_map
) )
@ -912,17 +956,17 @@ class EventsWorkerStore(SQLBaseStore):
return result_map return result_map
async def _enqueue_events(self, events): async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This """Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events. without having to create a new transaction for each request for events.
Args: Args:
events (Iterable[str]): events to be fetched. events: events to be fetched.
Returns: Returns:
Dict[str, Dict]: map from event id to row data from the database. A map from event id to row data from the database. May contain events
May contain events that weren't requested. that weren't requested.
""" """
events_d = defer.Deferred() events_d = defer.Deferred()
@ -949,43 +993,19 @@ class EventsWorkerStore(SQLBaseStore):
return row_map return row_map
def _fetch_event_rows(self, txn, event_ids): def _fetch_event_rows(
self, txn: LoggingTransaction, event_ids: Iterable[str]
) -> Dict[str, _EventRow]:
"""Fetch event rows from the database """Fetch event rows from the database
Events which are not found are omitted from the result. Events which are not found are omitted from the result.
The returned per-event dicts contain the following keys:
* event_id (str)
* stream_ordering (int): stream ordering for this event
* json (str): json-encoded event structure
* internal_metadata (str): json-encoded internal metadata dict
* format_version (int|None): The format of the event. Hopefully one
of EventFormatVersions. 'None' means the event predates
EventFormatVersions (so the event is format V1).
* room_version_id (str|None): The version of the room which contains the event.
Hopefully one of RoomVersions.
Due to historical reasons, there may be a few events in the database which
do not have an associated room; in this case None will be returned here.
* rejected_reason (str|None): if the event was rejected, the reason
why.
* redactions (List[str]): a list of event-ids which (claim to) redact
this event.
Args: Args:
txn (twisted.enterprise.adbapi.Connection): txn: The database transaction.
event_ids (Iterable[str]): event IDs to fetch event_ids: event IDs to fetch
Returns: Returns:
Dict[str, Dict]: a map from event id to event info. A map from event id to event info.
""" """
event_dict = {} event_dict = {}
for evs in batch_iter(event_ids, 200): for evs in batch_iter(event_ids, 200):
@ -1013,17 +1033,17 @@ class EventsWorkerStore(SQLBaseStore):
for row in txn: for row in txn:
event_id = row[0] event_id = row[0]
event_dict[event_id] = { event_dict[event_id] = _EventRow(
"event_id": event_id, event_id=event_id,
"stream_ordering": row[1], stream_ordering=row[1],
"internal_metadata": row[2], internal_metadata=row[2],
"json": row[3], json=row[3],
"format_version": row[4], format_version=row[4],
"room_version_id": row[5], room_version_id=row[5],
"rejected_reason": row[6], rejected_reason=row[6],
"redactions": [], redactions=[],
"outlier": row[7], outlier=row[7],
} )
# check for redactions # check for redactions
redactions_sql = "SELECT event_id, redacts FROM redactions WHERE " redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
@ -1035,7 +1055,7 @@ class EventsWorkerStore(SQLBaseStore):
for (redacter, redacted) in txn: for (redacter, redacted) in txn:
d = event_dict.get(redacted) d = event_dict.get(redacted)
if d: if d:
d["redactions"].append(redacter) d.redactions.append(redacter)
return event_dict return event_dict