Add some type hints to event_federation datastore (#12753)

Co-authored-by: David Robertson <david.m.robertson1@gmail.com>
This commit is contained in:
Dirk Klimpel 2022-05-18 17:02:10 +02:00 committed by GitHub
parent 682431efbe
commit 50ae4eafe1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 127 additions and 65 deletions

1
changelog.d/12753.misc Normal file
View file

@ -0,0 +1 @@
Add some type hints to datastore.

View file

@ -27,7 +27,6 @@ exclude = (?x)
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py
|synapse/storage/schema/
|tests/api/test_auth.py

View file

@ -53,6 +53,7 @@ class RoomBatchHandler:
# We want to use the successor event depth so they appear after `prev_event` because
# it has a larger `depth` but before the successor event because the `stream_ordering`
# is negative before the successor event.
assert most_recent_prev_event_id is not None
successor_event_ids = await self.store.get_successor_events(
most_recent_prev_event_id
)
@ -139,6 +140,7 @@ class RoomBatchHandler:
_,
) = await self.store.get_max_depth_of(event_ids)
# mapping from (type, state_key) -> state_event_id
assert most_recent_event_id is not None
prev_state_map = await self.state_store.get_state_ids_for_event(
most_recent_event_id
)

View file

@ -14,7 +14,17 @@
import itertools
import logging
from queue import Empty, PriorityQueue
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
cast,
)
import attr
from prometheus_client import Counter, Gauge
@ -33,7 +43,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@ -135,7 +145,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id)
room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
@ -158,7 +168,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
def _get_auth_chain_ids_using_cover_index_txn(
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
self,
txn: LoggingTransaction,
room_id: str,
event_ids: Collection[str],
include_given: bool,
) -> Set[str]:
"""Calculates the auth chain IDs using the chain index."""
@ -215,9 +229,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
chains: Dict[int, int] = {}
# Add all linked chains reachable from initial set of chains.
for batch in batch_iter(event_chains, 1000):
for batch2 in batch_iter(event_chains, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
@ -297,7 +311,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
front = set(event_ids)
while front:
new_front = set()
new_front: Set[str] = set()
for chunk in batch_iter(front, 100):
# Pull the auth events either from the cache or DB.
to_fetch: List[str] = [] # Event IDs to fetch from DB
@ -316,7 +330,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Note we need to batch up the results by event ID before
# adding to the cache.
to_cache = {}
to_cache: Dict[str, List[Tuple[str, int]]] = {}
for event_id, auth_event_id, auth_event_depth in txn:
to_cache.setdefault(event_id, []).append(
(auth_event_id, auth_event_depth)
@ -349,7 +363,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id)
room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
@ -370,7 +384,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
def _get_auth_chain_difference_using_cover_index_txn(
self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
) -> Set[str]:
"""Calculates the auth chain difference using the chain index.
@ -444,9 +458,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop)
for batch in batch_iter(set(seen_chains), 1000):
for batch2 in batch_iter(set(seen_chains), 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
@ -529,7 +543,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return result
def _get_auth_chain_difference_txn(
self, txn, state_sets: List[Set[str]]
self, txn: LoggingTransaction, state_sets: List[Set[str]]
) -> Set[str]:
"""Calculates the auth chain difference using a breadth first search.
@ -602,7 +616,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# I think building a temporary list with fetchall is more efficient than
# just `search.extend(txn)`, but this is unconfirmed
search.extend(txn.fetchall())
search.extend(cast(List[Tuple[int, str]], txn.fetchall()))
# sort by depth
search.sort()
@ -645,7 +659,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# We parse the results and add the to the `found` set and the
# cache (note we need to batch up the results by event ID before
# adding to the cache).
to_cache = {}
to_cache: Dict[str, List[Tuple[str, int]]] = {}
for event_id, auth_event_id, auth_event_depth in txn:
to_cache.setdefault(event_id, []).append(
(auth_event_id, auth_event_depth)
@ -696,7 +710,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return {eid for eid, n in event_to_missing_sets.items() if n}
async def get_oldest_event_ids_with_depth_in_room(
self, room_id
self, room_id: str
) -> List[Tuple[str, int]]:
"""Gets the oldest events(backwards extremities) in the room along with the
aproximate depth.
@ -713,7 +727,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
List of (event_id, depth) tuples
"""
def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
def get_oldest_event_ids_with_depth_in_room_txn(
txn: LoggingTransaction, room_id: str
) -> List[Tuple[str, int]]:
# Assemble a dictionary with event_id -> depth for the oldest events
# we know of in the room. Backwards extremeties are the oldest
# events we know of in the room but we only know of them because
@ -743,7 +759,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (room_id, False))
return txn.fetchall()
return cast(List[Tuple[str, int]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_oldest_event_ids_with_depth_in_room",
@ -752,7 +768,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
async def get_insertion_event_backward_extremities_in_room(
self, room_id
self, room_id: str
) -> List[Tuple[str, int]]:
"""Get the insertion events we know about that we haven't backfilled yet.
@ -768,7 +784,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
List of (event_id, depth) tuples
"""
def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
def get_insertion_event_backward_extremities_in_room_txn(
txn: LoggingTransaction, room_id: str
) -> List[Tuple[str, int]]:
sql = """
SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
/* We only want insertion events that are also marked as backwards extremities */
@ -780,7 +798,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
txn.execute(sql, (room_id,))
return txn.fetchall()
return cast(List[Tuple[str, int]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_insertion_event_backward_extremities_in_room",
@ -788,7 +806,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs
Args:
@ -817,7 +835,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return max_depth_event_id, current_max_depth
async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the min depth from a set of event IDs
Args:
@ -865,7 +883,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
)
def _get_prev_events_for_room_txn(self, txn, room_id: str):
def _get_prev_events_for_room_txn(
self, txn: LoggingTransaction, room_id: str
) -> List[str]:
# we just use the 10 newest events. Older events will become
# prev_events of future events.
@ -896,7 +916,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
sorted by extremity count.
"""
def _get_rooms_with_many_extremities_txn(txn):
def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]:
where_clause = "1=1"
if room_id_filter:
where_clause = "room_id NOT IN (%s)" % (
@ -937,7 +957,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"get_min_depth", self._get_min_depth_interaction, room_id
)
def _get_min_depth_interaction(self, txn, room_id):
def _get_min_depth_interaction(
self, txn: LoggingTransaction, room_id: str
) -> Optional[int]:
min_depth = self.db_pool.simple_select_one_onecol_txn(
txn,
table="room_depth",
@ -966,22 +988,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined]
# We don't always have a full stream_to_exterm_id table, e.g. after
# the upgrade that introduced it, so we make sure we never ask for a
# stream_ordering from before a restart
last_change = max(self._stream_order_on_start, last_change)
last_change = max(self._stream_order_on_start, last_change) # type: ignore[attr-defined]
# provided the last_change is recent enough, we now clamp the requested
# stream_ordering to it.
if last_change > self.stream_ordering_month_ago:
if last_change > self.stream_ordering_month_ago: # type: ignore[attr-defined]
stream_ordering = min(last_change, stream_ordering)
return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
@cached(max_entries=5000, num_args=2)
async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
async def _get_forward_extremeties_for_room(
self, room_id: str, stream_ordering: int
) -> List[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@ -989,7 +1013,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
stream_orderings from that point.
"""
if stream_ordering <= self.stream_ordering_month_ago:
if stream_ordering <= self.stream_ordering_month_ago: # type: ignore[attr-defined]
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
sql = """
@ -1002,7 +1026,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
WHERE room_id = ?
"""
def get_forward_extremeties_for_room_txn(txn):
def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
@ -1104,8 +1128,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
]
async def get_backfill_events(
self, room_id: str, seed_event_id_list: list, limit: int
):
self, room_id: str, seed_event_id_list: List[str], limit: int
) -> List[EventBase]:
"""Get a list of Events for a given topic that occurred before (and
including) the events in seed_event_id_list. Return a list of max size `limit`
@ -1123,10 +1147,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
events = await self.get_events_as_list(event_ids)
return sorted(
events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
# type-ignore: mypy doesn't like negating the Optional[int] stream_ordering.
# But it's never None, because these events were previously persisted to the DB.
events,
key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering), # type: ignore[operator]
)
def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit):
def _get_backfill_events(
self,
txn: LoggingTransaction,
room_id: str,
seed_event_id_list: List[str],
limit: int,
) -> Set[str]:
"""
We want to make sure that we do a breadth-first, "depth" ordered search.
We also handle navigating historical branches of history connected by
@ -1139,7 +1172,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
limit,
)
event_id_results = set()
event_id_results: Set[str] = set()
# In a PriorityQueue, the lowest valued entries are retrieved first.
# We're using depth as the priority in the queue and tie-break based on
@ -1147,7 +1180,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# highest and newest-in-time message. We add events to the queue with a
# negative depth so that we process the newest-in-time messages first
# going backwards in time. stream_ordering follows the same pattern.
queue = PriorityQueue()
queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue()
for seed_event_id in seed_event_id_list:
event_lookup_result = self.db_pool.simple_select_one_txn(
@ -1253,7 +1286,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return event_id_results
async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
async def get_missing_events(
self,
room_id: str,
earliest_events: List[str],
latest_events: List[str],
limit: int,
) -> List[EventBase]:
ids = await self.db_pool.runInteraction(
"get_missing_events",
self._get_missing_events,
@ -1264,11 +1303,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
return await self.get_events_as_list(ids)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
def _get_missing_events(
self,
txn: LoggingTransaction,
room_id: str,
earliest_events: List[str],
latest_events: List[str],
limit: int,
) -> List[str]:
seen_events = set(earliest_events)
front = set(latest_events) - seen_events
event_results = []
event_results: List[str] = []
query = (
"SELECT prev_event_id FROM event_edges "
@ -1311,7 +1357,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@wrap_as_background_process("delete_old_forward_extrem_cache")
async def _delete_old_forward_extrem_cache(self) -> None:
def _delete_old_forward_extrem_cache_txn(txn):
def _delete_old_forward_extrem_cache_txn(txn: LoggingTransaction) -> None:
# Delete entries older than a month, while making sure we don't delete
# the only entries for a room.
sql = """
@ -1324,7 +1370,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) AND stream_ordering < ?
"""
txn.execute(
sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) # type: ignore[attr-defined]
)
await self.db_pool.runInteraction(
@ -1382,7 +1428,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
if self.db_pool.engine.supports_returning:
def _remove_received_event_from_staging_txn(txn):
def _remove_received_event_from_staging_txn(
txn: LoggingTransaction,
) -> Optional[int]:
sql = """
DELETE FROM federation_inbound_events_staging
WHERE origin = ? AND event_id = ?
@ -1390,21 +1438,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
txn.execute(sql, (origin, event_id))
return txn.fetchone()
row = cast(Optional[Tuple[int]], txn.fetchone())
row = await self.db_pool.runInteraction(
if row is None:
return None
return row[0]
return await self.db_pool.runInteraction(
"remove_received_event_from_staging",
_remove_received_event_from_staging_txn,
db_autocommit=True,
)
if row is None:
return None
return row[0]
else:
def _remove_received_event_from_staging_txn(txn):
def _remove_received_event_from_staging_txn(
txn: LoggingTransaction,
) -> Optional[int]:
received_ts = self.db_pool.simple_select_one_onecol_txn(
txn,
table="federation_inbound_events_staging",
@ -1437,7 +1488,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) -> Optional[Tuple[str, str]]:
"""Get the next event ID in the staging area for the given room."""
def _get_next_staged_event_id_for_room_txn(txn):
def _get_next_staged_event_id_for_room_txn(
txn: LoggingTransaction,
) -> Optional[Tuple[str, str]]:
sql = """
SELECT origin, event_id
FROM federation_inbound_events_staging
@ -1448,7 +1501,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (room_id,))
return txn.fetchone()
return cast(Optional[Tuple[str, str]], txn.fetchone())
return await self.db_pool.runInteraction(
"get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn
@ -1461,7 +1514,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) -> Optional[Tuple[str, EventBase]]:
"""Get the next event in the staging area for the given room."""
def _get_next_staged_event_for_room_txn(txn):
def _get_next_staged_event_for_room_txn(
txn: LoggingTransaction,
) -> Optional[Tuple[str, str, str]]:
sql = """
SELECT event_json, internal_metadata, origin
FROM federation_inbound_events_staging
@ -1471,7 +1526,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
txn.execute(sql, (room_id,))
return txn.fetchone()
return cast(Optional[Tuple[str, str, str]], txn.fetchone())
row = await self.db_pool.runInteraction(
"get_next_staged_event_for_room", _get_next_staged_event_for_room_txn
@ -1599,18 +1654,20 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
@wrap_as_background_process("_get_stats_for_federation_staging")
async def _get_stats_for_federation_staging(self):
async def _get_stats_for_federation_staging(self) -> None:
"""Update the prometheus metrics for the inbound federation staging area."""
def _get_stats_for_federation_staging_txn(txn):
def _get_stats_for_federation_staging_txn(
txn: LoggingTransaction,
) -> Tuple[int, int]:
txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
txn.execute(
"SELECT min(received_ts) FROM federation_inbound_events_staging"
)
(received_ts,) = txn.fetchone()
(received_ts,) = cast(Tuple[Optional[int]], txn.fetchone())
# If there is nothing in the staging area default it to 0.
age = 0
@ -1651,19 +1708,21 @@ class EventFederationStore(EventFederationWorkerStore):
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
async def clean_room_for_join(self, room_id):
return await self.db_pool.runInteraction(
async def clean_room_for_join(self, room_id: str) -> None:
await self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)
def _clean_room_for_join_txn(self, txn, room_id):
def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None:
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
async def _background_delete_non_state_event_auth(self, progress, batch_size):
def delete_event_auth(txn):
async def _background_delete_non_state_event_auth(
self, progress: JsonDict, batch_size: int
) -> int:
def delete_event_auth(txn: LoggingTransaction) -> bool:
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
max_stream_id = progress.get("max_stream_id_exclusive")

View file

@ -332,6 +332,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
most_recent_prev_event_depth,
) = self.get_success(self.store.get_max_depth_of(prev_event_ids))
# mapping from (type, state_key) -> state_event_id
assert most_recent_prev_event_id is not None
prev_state_map = self.get_success(
self.state_store.get_state_ids_for_event(most_recent_prev_event_id)
)