mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 20:03:54 +01:00
Add some type hints to event_federation
datastore (#12753)
Co-authored-by: David Robertson <david.m.robertson1@gmail.com>
This commit is contained in:
parent
682431efbe
commit
50ae4eafe1
5 changed files with 127 additions and 65 deletions
1
changelog.d/12753.misc
Normal file
1
changelog.d/12753.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add some type hints to datastore.
|
1
mypy.ini
1
mypy.ini
|
@ -27,7 +27,6 @@ exclude = (?x)
|
|||
|synapse/storage/databases/__init__.py
|
||||
|synapse/storage/databases/main/cache.py
|
||||
|synapse/storage/databases/main/devices.py
|
||||
|synapse/storage/databases/main/event_federation.py
|
||||
|synapse/storage/schema/
|
||||
|
||||
|tests/api/test_auth.py
|
||||
|
|
|
@ -53,6 +53,7 @@ class RoomBatchHandler:
|
|||
# We want to use the successor event depth so they appear after `prev_event` because
|
||||
# it has a larger `depth` but before the successor event because the `stream_ordering`
|
||||
# is negative before the successor event.
|
||||
assert most_recent_prev_event_id is not None
|
||||
successor_event_ids = await self.store.get_successor_events(
|
||||
most_recent_prev_event_id
|
||||
)
|
||||
|
@ -139,6 +140,7 @@ class RoomBatchHandler:
|
|||
_,
|
||||
) = await self.store.get_max_depth_of(event_ids)
|
||||
# mapping from (type, state_key) -> state_event_id
|
||||
assert most_recent_event_id is not None
|
||||
prev_state_map = await self.state_store.get_state_ids_for_event(
|
||||
most_recent_event_id
|
||||
)
|
||||
|
|
|
@ -14,7 +14,17 @@
|
|||
import itertools
|
||||
import logging
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter, Gauge
|
||||
|
@ -33,7 +43,7 @@ from synapse.storage.database import (
|
|||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
@ -135,7 +145,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
|
||||
# Check if we have indexed the room so we can use the chain cover
|
||||
# algorithm.
|
||||
room = await self.get_room(room_id)
|
||||
room = await self.get_room(room_id) # type: ignore[attr-defined]
|
||||
if room["has_auth_chain_index"]:
|
||||
try:
|
||||
return await self.db_pool.runInteraction(
|
||||
|
@ -158,7 +168,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
)
|
||||
|
||||
def _get_auth_chain_ids_using_cover_index_txn(
|
||||
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
event_ids: Collection[str],
|
||||
include_given: bool,
|
||||
) -> Set[str]:
|
||||
"""Calculates the auth chain IDs using the chain index."""
|
||||
|
||||
|
@ -215,9 +229,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
chains: Dict[int, int] = {}
|
||||
|
||||
# Add all linked chains reachable from initial set of chains.
|
||||
for batch in batch_iter(event_chains, 1000):
|
||||
for batch2 in batch_iter(event_chains, 1000):
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "origin_chain_id", batch
|
||||
txn.database_engine, "origin_chain_id", batch2
|
||||
)
|
||||
txn.execute(sql % (clause,), args)
|
||||
|
||||
|
@ -297,7 +311,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
|
||||
front = set(event_ids)
|
||||
while front:
|
||||
new_front = set()
|
||||
new_front: Set[str] = set()
|
||||
for chunk in batch_iter(front, 100):
|
||||
# Pull the auth events either from the cache or DB.
|
||||
to_fetch: List[str] = [] # Event IDs to fetch from DB
|
||||
|
@ -316,7 +330,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
|
||||
# Note we need to batch up the results by event ID before
|
||||
# adding to the cache.
|
||||
to_cache = {}
|
||||
to_cache: Dict[str, List[Tuple[str, int]]] = {}
|
||||
for event_id, auth_event_id, auth_event_depth in txn:
|
||||
to_cache.setdefault(event_id, []).append(
|
||||
(auth_event_id, auth_event_depth)
|
||||
|
@ -349,7 +363,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
|
||||
# Check if we have indexed the room so we can use the chain cover
|
||||
# algorithm.
|
||||
room = await self.get_room(room_id)
|
||||
room = await self.get_room(room_id) # type: ignore[attr-defined]
|
||||
if room["has_auth_chain_index"]:
|
||||
try:
|
||||
return await self.db_pool.runInteraction(
|
||||
|
@ -370,7 +384,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
)
|
||||
|
||||
def _get_auth_chain_difference_using_cover_index_txn(
|
||||
self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
|
||||
self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
|
||||
) -> Set[str]:
|
||||
"""Calculates the auth chain difference using the chain index.
|
||||
|
||||
|
@ -444,9 +458,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
|
||||
# (We need to take a copy of `seen_chains` as we want to mutate it in
|
||||
# the loop)
|
||||
for batch in batch_iter(set(seen_chains), 1000):
|
||||
for batch2 in batch_iter(set(seen_chains), 1000):
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "origin_chain_id", batch
|
||||
txn.database_engine, "origin_chain_id", batch2
|
||||
)
|
||||
txn.execute(sql % (clause,), args)
|
||||
|
||||
|
@ -529,7 +543,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
return result
|
||||
|
||||
def _get_auth_chain_difference_txn(
|
||||
self, txn, state_sets: List[Set[str]]
|
||||
self, txn: LoggingTransaction, state_sets: List[Set[str]]
|
||||
) -> Set[str]:
|
||||
"""Calculates the auth chain difference using a breadth first search.
|
||||
|
||||
|
@ -602,7 +616,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
|
||||
# I think building a temporary list with fetchall is more efficient than
|
||||
# just `search.extend(txn)`, but this is unconfirmed
|
||||
search.extend(txn.fetchall())
|
||||
search.extend(cast(List[Tuple[int, str]], txn.fetchall()))
|
||||
|
||||
# sort by depth
|
||||
search.sort()
|
||||
|
@ -645,7 +659,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
# We parse the results and add the to the `found` set and the
|
||||
# cache (note we need to batch up the results by event ID before
|
||||
# adding to the cache).
|
||||
to_cache = {}
|
||||
to_cache: Dict[str, List[Tuple[str, int]]] = {}
|
||||
for event_id, auth_event_id, auth_event_depth in txn:
|
||||
to_cache.setdefault(event_id, []).append(
|
||||
(auth_event_id, auth_event_depth)
|
||||
|
@ -696,7 +710,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
return {eid for eid, n in event_to_missing_sets.items() if n}
|
||||
|
||||
async def get_oldest_event_ids_with_depth_in_room(
|
||||
self, room_id
|
||||
self, room_id: str
|
||||
) -> List[Tuple[str, int]]:
|
||||
"""Gets the oldest events(backwards extremities) in the room along with the
|
||||
aproximate depth.
|
||||
|
@ -713,7 +727,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
List of (event_id, depth) tuples
|
||||
"""
|
||||
|
||||
def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
|
||||
def get_oldest_event_ids_with_depth_in_room_txn(
|
||||
txn: LoggingTransaction, room_id: str
|
||||
) -> List[Tuple[str, int]]:
|
||||
# Assemble a dictionary with event_id -> depth for the oldest events
|
||||
# we know of in the room. Backwards extremeties are the oldest
|
||||
# events we know of in the room but we only know of them because
|
||||
|
@ -743,7 +759,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
|
||||
txn.execute(sql, (room_id, False))
|
||||
|
||||
return txn.fetchall()
|
||||
return cast(List[Tuple[str, int]], txn.fetchall())
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_oldest_event_ids_with_depth_in_room",
|
||||
|
@ -752,7 +768,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
)
|
||||
|
||||
async def get_insertion_event_backward_extremities_in_room(
|
||||
self, room_id
|
||||
self, room_id: str
|
||||
) -> List[Tuple[str, int]]:
|
||||
"""Get the insertion events we know about that we haven't backfilled yet.
|
||||
|
||||
|
@ -768,7 +784,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
List of (event_id, depth) tuples
|
||||
"""
|
||||
|
||||
def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
|
||||
def get_insertion_event_backward_extremities_in_room_txn(
|
||||
txn: LoggingTransaction, room_id: str
|
||||
) -> List[Tuple[str, int]]:
|
||||
sql = """
|
||||
SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
|
||||
/* We only want insertion events that are also marked as backwards extremities */
|
||||
|
@ -780,7 +798,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
"""
|
||||
|
||||
txn.execute(sql, (room_id,))
|
||||
return txn.fetchall()
|
||||
return cast(List[Tuple[str, int]], txn.fetchall())
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_insertion_event_backward_extremities_in_room",
|
||||
|
@ -788,7 +806,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
room_id,
|
||||
)
|
||||
|
||||
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
|
||||
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
|
||||
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs
|
||||
|
||||
Args:
|
||||
|
@ -817,7 +835,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
|
||||
return max_depth_event_id, current_max_depth
|
||||
|
||||
async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
|
||||
async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
|
||||
"""Returns the event ID and depth for the event that has the min depth from a set of event IDs
|
||||
|
||||
Args:
|
||||
|
@ -865,7 +883,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
|
||||
)
|
||||
|
||||
def _get_prev_events_for_room_txn(self, txn, room_id: str):
|
||||
def _get_prev_events_for_room_txn(
|
||||
self, txn: LoggingTransaction, room_id: str
|
||||
) -> List[str]:
|
||||
# we just use the 10 newest events. Older events will become
|
||||
# prev_events of future events.
|
||||
|
||||
|
@ -896,7 +916,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
sorted by extremity count.
|
||||
"""
|
||||
|
||||
def _get_rooms_with_many_extremities_txn(txn):
|
||||
def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]:
|
||||
where_clause = "1=1"
|
||||
if room_id_filter:
|
||||
where_clause = "room_id NOT IN (%s)" % (
|
||||
|
@ -937,7 +957,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
"get_min_depth", self._get_min_depth_interaction, room_id
|
||||
)
|
||||
|
||||
def _get_min_depth_interaction(self, txn, room_id):
|
||||
def _get_min_depth_interaction(
|
||||
self, txn: LoggingTransaction, room_id: str
|
||||
) -> Optional[int]:
|
||||
min_depth = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="room_depth",
|
||||
|
@ -966,22 +988,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
"""
|
||||
# We want to make the cache more effective, so we clamp to the last
|
||||
# change before the given ordering.
|
||||
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
|
||||
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined]
|
||||
|
||||
# We don't always have a full stream_to_exterm_id table, e.g. after
|
||||
# the upgrade that introduced it, so we make sure we never ask for a
|
||||
# stream_ordering from before a restart
|
||||
last_change = max(self._stream_order_on_start, last_change)
|
||||
last_change = max(self._stream_order_on_start, last_change) # type: ignore[attr-defined]
|
||||
|
||||
# provided the last_change is recent enough, we now clamp the requested
|
||||
# stream_ordering to it.
|
||||
if last_change > self.stream_ordering_month_ago:
|
||||
if last_change > self.stream_ordering_month_ago: # type: ignore[attr-defined]
|
||||
stream_ordering = min(last_change, stream_ordering)
|
||||
|
||||
return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
|
||||
|
||||
@cached(max_entries=5000, num_args=2)
|
||||
async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
|
||||
async def _get_forward_extremeties_for_room(
|
||||
self, room_id: str, stream_ordering: int
|
||||
) -> List[str]:
|
||||
"""For a given room_id and stream_ordering, return the forward
|
||||
extremeties of the room at that point in "time".
|
||||
|
||||
|
@ -989,7 +1013,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
stream_orderings from that point.
|
||||
"""
|
||||
|
||||
if stream_ordering <= self.stream_ordering_month_ago:
|
||||
if stream_ordering <= self.stream_ordering_month_ago: # type: ignore[attr-defined]
|
||||
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
|
||||
|
||||
sql = """
|
||||
|
@ -1002,7 +1026,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
WHERE room_id = ?
|
||||
"""
|
||||
|
||||
def get_forward_extremeties_for_room_txn(txn):
|
||||
def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]:
|
||||
txn.execute(sql, (stream_ordering, room_id))
|
||||
return [event_id for event_id, in txn]
|
||||
|
||||
|
@ -1104,8 +1128,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
]
|
||||
|
||||
async def get_backfill_events(
|
||||
self, room_id: str, seed_event_id_list: list, limit: int
|
||||
):
|
||||
self, room_id: str, seed_event_id_list: List[str], limit: int
|
||||
) -> List[EventBase]:
|
||||
"""Get a list of Events for a given topic that occurred before (and
|
||||
including) the events in seed_event_id_list. Return a list of max size `limit`
|
||||
|
||||
|
@ -1123,10 +1147,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
)
|
||||
events = await self.get_events_as_list(event_ids)
|
||||
return sorted(
|
||||
events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
|
||||
# type-ignore: mypy doesn't like negating the Optional[int] stream_ordering.
|
||||
# But it's never None, because these events were previously persisted to the DB.
|
||||
events,
|
||||
key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering), # type: ignore[operator]
|
||||
)
|
||||
|
||||
def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit):
|
||||
def _get_backfill_events(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
seed_event_id_list: List[str],
|
||||
limit: int,
|
||||
) -> Set[str]:
|
||||
"""
|
||||
We want to make sure that we do a breadth-first, "depth" ordered search.
|
||||
We also handle navigating historical branches of history connected by
|
||||
|
@ -1139,7 +1172,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
limit,
|
||||
)
|
||||
|
||||
event_id_results = set()
|
||||
event_id_results: Set[str] = set()
|
||||
|
||||
# In a PriorityQueue, the lowest valued entries are retrieved first.
|
||||
# We're using depth as the priority in the queue and tie-break based on
|
||||
|
@ -1147,7 +1180,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
# highest and newest-in-time message. We add events to the queue with a
|
||||
# negative depth so that we process the newest-in-time messages first
|
||||
# going backwards in time. stream_ordering follows the same pattern.
|
||||
queue = PriorityQueue()
|
||||
queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue()
|
||||
|
||||
for seed_event_id in seed_event_id_list:
|
||||
event_lookup_result = self.db_pool.simple_select_one_txn(
|
||||
|
@ -1253,7 +1286,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
|
||||
return event_id_results
|
||||
|
||||
async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
|
||||
async def get_missing_events(
|
||||
self,
|
||||
room_id: str,
|
||||
earliest_events: List[str],
|
||||
latest_events: List[str],
|
||||
limit: int,
|
||||
) -> List[EventBase]:
|
||||
ids = await self.db_pool.runInteraction(
|
||||
"get_missing_events",
|
||||
self._get_missing_events,
|
||||
|
@ -1264,11 +1303,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
)
|
||||
return await self.get_events_as_list(ids)
|
||||
|
||||
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
|
||||
def _get_missing_events(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
earliest_events: List[str],
|
||||
latest_events: List[str],
|
||||
limit: int,
|
||||
) -> List[str]:
|
||||
|
||||
seen_events = set(earliest_events)
|
||||
front = set(latest_events) - seen_events
|
||||
event_results = []
|
||||
event_results: List[str] = []
|
||||
|
||||
query = (
|
||||
"SELECT prev_event_id FROM event_edges "
|
||||
|
@ -1311,7 +1357,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
|
||||
@wrap_as_background_process("delete_old_forward_extrem_cache")
|
||||
async def _delete_old_forward_extrem_cache(self) -> None:
|
||||
def _delete_old_forward_extrem_cache_txn(txn):
|
||||
def _delete_old_forward_extrem_cache_txn(txn: LoggingTransaction) -> None:
|
||||
# Delete entries older than a month, while making sure we don't delete
|
||||
# the only entries for a room.
|
||||
sql = """
|
||||
|
@ -1324,7 +1370,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
) AND stream_ordering < ?
|
||||
"""
|
||||
txn.execute(
|
||||
sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
|
||||
sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
|
@ -1382,7 +1428,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
"""
|
||||
if self.db_pool.engine.supports_returning:
|
||||
|
||||
def _remove_received_event_from_staging_txn(txn):
|
||||
def _remove_received_event_from_staging_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[int]:
|
||||
sql = """
|
||||
DELETE FROM federation_inbound_events_staging
|
||||
WHERE origin = ? AND event_id = ?
|
||||
|
@ -1390,21 +1438,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
"""
|
||||
|
||||
txn.execute(sql, (origin, event_id))
|
||||
return txn.fetchone()
|
||||
row = cast(Optional[Tuple[int]], txn.fetchone())
|
||||
|
||||
row = await self.db_pool.runInteraction(
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return row[0]
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"remove_received_event_from_staging",
|
||||
_remove_received_event_from_staging_txn,
|
||||
db_autocommit=True,
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return row[0]
|
||||
|
||||
else:
|
||||
|
||||
def _remove_received_event_from_staging_txn(txn):
|
||||
def _remove_received_event_from_staging_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[int]:
|
||||
received_ts = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="federation_inbound_events_staging",
|
||||
|
@ -1437,7 +1488,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
) -> Optional[Tuple[str, str]]:
|
||||
"""Get the next event ID in the staging area for the given room."""
|
||||
|
||||
def _get_next_staged_event_id_for_room_txn(txn):
|
||||
def _get_next_staged_event_id_for_room_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[Tuple[str, str]]:
|
||||
sql = """
|
||||
SELECT origin, event_id
|
||||
FROM federation_inbound_events_staging
|
||||
|
@ -1448,7 +1501,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
|
||||
txn.execute(sql, (room_id,))
|
||||
|
||||
return txn.fetchone()
|
||||
return cast(Optional[Tuple[str, str]], txn.fetchone())
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn
|
||||
|
@ -1461,7 +1514,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
) -> Optional[Tuple[str, EventBase]]:
|
||||
"""Get the next event in the staging area for the given room."""
|
||||
|
||||
def _get_next_staged_event_for_room_txn(txn):
|
||||
def _get_next_staged_event_for_room_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[Tuple[str, str, str]]:
|
||||
sql = """
|
||||
SELECT event_json, internal_metadata, origin
|
||||
FROM federation_inbound_events_staging
|
||||
|
@ -1471,7 +1526,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
"""
|
||||
txn.execute(sql, (room_id,))
|
||||
|
||||
return txn.fetchone()
|
||||
return cast(Optional[Tuple[str, str, str]], txn.fetchone())
|
||||
|
||||
row = await self.db_pool.runInteraction(
|
||||
"get_next_staged_event_for_room", _get_next_staged_event_for_room_txn
|
||||
|
@ -1599,18 +1654,20 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
|||
)
|
||||
|
||||
@wrap_as_background_process("_get_stats_for_federation_staging")
|
||||
async def _get_stats_for_federation_staging(self):
|
||||
async def _get_stats_for_federation_staging(self) -> None:
|
||||
"""Update the prometheus metrics for the inbound federation staging area."""
|
||||
|
||||
def _get_stats_for_federation_staging_txn(txn):
|
||||
def _get_stats_for_federation_staging_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[int, int]:
|
||||
txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
|
||||
(count,) = txn.fetchone()
|
||||
(count,) = cast(Tuple[int], txn.fetchone())
|
||||
|
||||
txn.execute(
|
||||
"SELECT min(received_ts) FROM federation_inbound_events_staging"
|
||||
)
|
||||
|
||||
(received_ts,) = txn.fetchone()
|
||||
(received_ts,) = cast(Tuple[Optional[int]], txn.fetchone())
|
||||
|
||||
# If there is nothing in the staging area default it to 0.
|
||||
age = 0
|
||||
|
@ -1651,19 +1708,21 @@ class EventFederationStore(EventFederationWorkerStore):
|
|||
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
|
||||
)
|
||||
|
||||
async def clean_room_for_join(self, room_id):
|
||||
return await self.db_pool.runInteraction(
|
||||
async def clean_room_for_join(self, room_id: str) -> None:
|
||||
await self.db_pool.runInteraction(
|
||||
"clean_room_for_join", self._clean_room_for_join_txn, room_id
|
||||
)
|
||||
|
||||
def _clean_room_for_join_txn(self, txn, room_id):
|
||||
def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None:
|
||||
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
|
||||
|
||||
txn.execute(query, (room_id,))
|
||||
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
|
||||
|
||||
async def _background_delete_non_state_event_auth(self, progress, batch_size):
|
||||
def delete_event_auth(txn):
|
||||
async def _background_delete_non_state_event_auth(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
def delete_event_auth(txn: LoggingTransaction) -> bool:
|
||||
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
|
||||
max_stream_id = progress.get("max_stream_id_exclusive")
|
||||
|
||||
|
|
|
@ -332,6 +332,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
|||
most_recent_prev_event_depth,
|
||||
) = self.get_success(self.store.get_max_depth_of(prev_event_ids))
|
||||
# mapping from (type, state_key) -> state_event_id
|
||||
assert most_recent_prev_event_id is not None
|
||||
prev_state_map = self.get_success(
|
||||
self.state_store.get_state_ids_for_event(most_recent_prev_event_id)
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue