mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 11:53:49 +01:00
Add some type hints to event_federation
datastore (#12753)
Co-authored-by: David Robertson <david.m.robertson1@gmail.com>
This commit is contained in:
parent
682431efbe
commit
50ae4eafe1
5 changed files with 127 additions and 65 deletions
1
changelog.d/12753.misc
Normal file
1
changelog.d/12753.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add some type hints to datastore.
|
1
mypy.ini
1
mypy.ini
|
@ -27,7 +27,6 @@ exclude = (?x)
|
||||||
|synapse/storage/databases/__init__.py
|
|synapse/storage/databases/__init__.py
|
||||||
|synapse/storage/databases/main/cache.py
|
|synapse/storage/databases/main/cache.py
|
||||||
|synapse/storage/databases/main/devices.py
|
|synapse/storage/databases/main/devices.py
|
||||||
|synapse/storage/databases/main/event_federation.py
|
|
||||||
|synapse/storage/schema/
|
|synapse/storage/schema/
|
||||||
|
|
||||||
|tests/api/test_auth.py
|
|tests/api/test_auth.py
|
||||||
|
|
|
@ -53,6 +53,7 @@ class RoomBatchHandler:
|
||||||
# We want to use the successor event depth so they appear after `prev_event` because
|
# We want to use the successor event depth so they appear after `prev_event` because
|
||||||
# it has a larger `depth` but before the successor event because the `stream_ordering`
|
# it has a larger `depth` but before the successor event because the `stream_ordering`
|
||||||
# is negative before the successor event.
|
# is negative before the successor event.
|
||||||
|
assert most_recent_prev_event_id is not None
|
||||||
successor_event_ids = await self.store.get_successor_events(
|
successor_event_ids = await self.store.get_successor_events(
|
||||||
most_recent_prev_event_id
|
most_recent_prev_event_id
|
||||||
)
|
)
|
||||||
|
@ -139,6 +140,7 @@ class RoomBatchHandler:
|
||||||
_,
|
_,
|
||||||
) = await self.store.get_max_depth_of(event_ids)
|
) = await self.store.get_max_depth_of(event_ids)
|
||||||
# mapping from (type, state_key) -> state_event_id
|
# mapping from (type, state_key) -> state_event_id
|
||||||
|
assert most_recent_event_id is not None
|
||||||
prev_state_map = await self.state_store.get_state_ids_for_event(
|
prev_state_map = await self.state_store.get_state_ids_for_event(
|
||||||
most_recent_event_id
|
most_recent_event_id
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,7 +14,17 @@
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from queue import Empty, PriorityQueue
|
from queue import Empty, PriorityQueue
|
||||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from prometheus_client import Counter, Gauge
|
from prometheus_client import Counter, Gauge
|
||||||
|
@ -33,7 +43,7 @@ from synapse.storage.database import (
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.types import Cursor
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
@ -135,7 +145,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
# Check if we have indexed the room so we can use the chain cover
|
# Check if we have indexed the room so we can use the chain cover
|
||||||
# algorithm.
|
# algorithm.
|
||||||
room = await self.get_room(room_id)
|
room = await self.get_room(room_id) # type: ignore[attr-defined]
|
||||||
if room["has_auth_chain_index"]:
|
if room["has_auth_chain_index"]:
|
||||||
try:
|
try:
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
|
@ -158,7 +168,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_auth_chain_ids_using_cover_index_txn(
|
def _get_auth_chain_ids_using_cover_index_txn(
|
||||||
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
room_id: str,
|
||||||
|
event_ids: Collection[str],
|
||||||
|
include_given: bool,
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
"""Calculates the auth chain IDs using the chain index."""
|
"""Calculates the auth chain IDs using the chain index."""
|
||||||
|
|
||||||
|
@ -215,9 +229,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
chains: Dict[int, int] = {}
|
chains: Dict[int, int] = {}
|
||||||
|
|
||||||
# Add all linked chains reachable from initial set of chains.
|
# Add all linked chains reachable from initial set of chains.
|
||||||
for batch in batch_iter(event_chains, 1000):
|
for batch2 in batch_iter(event_chains, 1000):
|
||||||
clause, args = make_in_list_sql_clause(
|
clause, args = make_in_list_sql_clause(
|
||||||
txn.database_engine, "origin_chain_id", batch
|
txn.database_engine, "origin_chain_id", batch2
|
||||||
)
|
)
|
||||||
txn.execute(sql % (clause,), args)
|
txn.execute(sql % (clause,), args)
|
||||||
|
|
||||||
|
@ -297,7 +311,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
front = set(event_ids)
|
front = set(event_ids)
|
||||||
while front:
|
while front:
|
||||||
new_front = set()
|
new_front: Set[str] = set()
|
||||||
for chunk in batch_iter(front, 100):
|
for chunk in batch_iter(front, 100):
|
||||||
# Pull the auth events either from the cache or DB.
|
# Pull the auth events either from the cache or DB.
|
||||||
to_fetch: List[str] = [] # Event IDs to fetch from DB
|
to_fetch: List[str] = [] # Event IDs to fetch from DB
|
||||||
|
@ -316,7 +330,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
# Note we need to batch up the results by event ID before
|
# Note we need to batch up the results by event ID before
|
||||||
# adding to the cache.
|
# adding to the cache.
|
||||||
to_cache = {}
|
to_cache: Dict[str, List[Tuple[str, int]]] = {}
|
||||||
for event_id, auth_event_id, auth_event_depth in txn:
|
for event_id, auth_event_id, auth_event_depth in txn:
|
||||||
to_cache.setdefault(event_id, []).append(
|
to_cache.setdefault(event_id, []).append(
|
||||||
(auth_event_id, auth_event_depth)
|
(auth_event_id, auth_event_depth)
|
||||||
|
@ -349,7 +363,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
# Check if we have indexed the room so we can use the chain cover
|
# Check if we have indexed the room so we can use the chain cover
|
||||||
# algorithm.
|
# algorithm.
|
||||||
room = await self.get_room(room_id)
|
room = await self.get_room(room_id) # type: ignore[attr-defined]
|
||||||
if room["has_auth_chain_index"]:
|
if room["has_auth_chain_index"]:
|
||||||
try:
|
try:
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
|
@ -370,7 +384,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_auth_chain_difference_using_cover_index_txn(
|
def _get_auth_chain_difference_using_cover_index_txn(
|
||||||
self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
|
self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
"""Calculates the auth chain difference using the chain index.
|
"""Calculates the auth chain difference using the chain index.
|
||||||
|
|
||||||
|
@ -444,9 +458,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
# (We need to take a copy of `seen_chains` as we want to mutate it in
|
# (We need to take a copy of `seen_chains` as we want to mutate it in
|
||||||
# the loop)
|
# the loop)
|
||||||
for batch in batch_iter(set(seen_chains), 1000):
|
for batch2 in batch_iter(set(seen_chains), 1000):
|
||||||
clause, args = make_in_list_sql_clause(
|
clause, args = make_in_list_sql_clause(
|
||||||
txn.database_engine, "origin_chain_id", batch
|
txn.database_engine, "origin_chain_id", batch2
|
||||||
)
|
)
|
||||||
txn.execute(sql % (clause,), args)
|
txn.execute(sql % (clause,), args)
|
||||||
|
|
||||||
|
@ -529,7 +543,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _get_auth_chain_difference_txn(
|
def _get_auth_chain_difference_txn(
|
||||||
self, txn, state_sets: List[Set[str]]
|
self, txn: LoggingTransaction, state_sets: List[Set[str]]
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
"""Calculates the auth chain difference using a breadth first search.
|
"""Calculates the auth chain difference using a breadth first search.
|
||||||
|
|
||||||
|
@ -602,7 +616,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
# I think building a temporary list with fetchall is more efficient than
|
# I think building a temporary list with fetchall is more efficient than
|
||||||
# just `search.extend(txn)`, but this is unconfirmed
|
# just `search.extend(txn)`, but this is unconfirmed
|
||||||
search.extend(txn.fetchall())
|
search.extend(cast(List[Tuple[int, str]], txn.fetchall()))
|
||||||
|
|
||||||
# sort by depth
|
# sort by depth
|
||||||
search.sort()
|
search.sort()
|
||||||
|
@ -645,7 +659,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
# We parse the results and add the to the `found` set and the
|
# We parse the results and add the to the `found` set and the
|
||||||
# cache (note we need to batch up the results by event ID before
|
# cache (note we need to batch up the results by event ID before
|
||||||
# adding to the cache).
|
# adding to the cache).
|
||||||
to_cache = {}
|
to_cache: Dict[str, List[Tuple[str, int]]] = {}
|
||||||
for event_id, auth_event_id, auth_event_depth in txn:
|
for event_id, auth_event_id, auth_event_depth in txn:
|
||||||
to_cache.setdefault(event_id, []).append(
|
to_cache.setdefault(event_id, []).append(
|
||||||
(auth_event_id, auth_event_depth)
|
(auth_event_id, auth_event_depth)
|
||||||
|
@ -696,7 +710,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
return {eid for eid, n in event_to_missing_sets.items() if n}
|
return {eid for eid, n in event_to_missing_sets.items() if n}
|
||||||
|
|
||||||
async def get_oldest_event_ids_with_depth_in_room(
|
async def get_oldest_event_ids_with_depth_in_room(
|
||||||
self, room_id
|
self, room_id: str
|
||||||
) -> List[Tuple[str, int]]:
|
) -> List[Tuple[str, int]]:
|
||||||
"""Gets the oldest events(backwards extremities) in the room along with the
|
"""Gets the oldest events(backwards extremities) in the room along with the
|
||||||
aproximate depth.
|
aproximate depth.
|
||||||
|
@ -713,7 +727,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
List of (event_id, depth) tuples
|
List of (event_id, depth) tuples
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
|
def get_oldest_event_ids_with_depth_in_room_txn(
|
||||||
|
txn: LoggingTransaction, room_id: str
|
||||||
|
) -> List[Tuple[str, int]]:
|
||||||
# Assemble a dictionary with event_id -> depth for the oldest events
|
# Assemble a dictionary with event_id -> depth for the oldest events
|
||||||
# we know of in the room. Backwards extremeties are the oldest
|
# we know of in the room. Backwards extremeties are the oldest
|
||||||
# events we know of in the room but we only know of them because
|
# events we know of in the room but we only know of them because
|
||||||
|
@ -743,7 +759,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
txn.execute(sql, (room_id, False))
|
txn.execute(sql, (room_id, False))
|
||||||
|
|
||||||
return txn.fetchall()
|
return cast(List[Tuple[str, int]], txn.fetchall())
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_oldest_event_ids_with_depth_in_room",
|
"get_oldest_event_ids_with_depth_in_room",
|
||||||
|
@ -752,7 +768,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_insertion_event_backward_extremities_in_room(
|
async def get_insertion_event_backward_extremities_in_room(
|
||||||
self, room_id
|
self, room_id: str
|
||||||
) -> List[Tuple[str, int]]:
|
) -> List[Tuple[str, int]]:
|
||||||
"""Get the insertion events we know about that we haven't backfilled yet.
|
"""Get the insertion events we know about that we haven't backfilled yet.
|
||||||
|
|
||||||
|
@ -768,7 +784,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
List of (event_id, depth) tuples
|
List of (event_id, depth) tuples
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
|
def get_insertion_event_backward_extremities_in_room_txn(
|
||||||
|
txn: LoggingTransaction, room_id: str
|
||||||
|
) -> List[Tuple[str, int]]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
|
SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
|
||||||
/* We only want insertion events that are also marked as backwards extremities */
|
/* We only want insertion events that are also marked as backwards extremities */
|
||||||
|
@ -780,7 +798,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
"""
|
"""
|
||||||
|
|
||||||
txn.execute(sql, (room_id,))
|
txn.execute(sql, (room_id,))
|
||||||
return txn.fetchall()
|
return cast(List[Tuple[str, int]], txn.fetchall())
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_insertion_event_backward_extremities_in_room",
|
"get_insertion_event_backward_extremities_in_room",
|
||||||
|
@ -788,7 +806,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
room_id,
|
room_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
|
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
|
||||||
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs
|
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -817,7 +835,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
return max_depth_event_id, current_max_depth
|
return max_depth_event_id, current_max_depth
|
||||||
|
|
||||||
async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
|
async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
|
||||||
"""Returns the event ID and depth for the event that has the min depth from a set of event IDs
|
"""Returns the event ID and depth for the event that has the min depth from a set of event IDs
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -865,7 +883,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
|
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_prev_events_for_room_txn(self, txn, room_id: str):
|
def _get_prev_events_for_room_txn(
|
||||||
|
self, txn: LoggingTransaction, room_id: str
|
||||||
|
) -> List[str]:
|
||||||
# we just use the 10 newest events. Older events will become
|
# we just use the 10 newest events. Older events will become
|
||||||
# prev_events of future events.
|
# prev_events of future events.
|
||||||
|
|
||||||
|
@ -896,7 +916,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
sorted by extremity count.
|
sorted by extremity count.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_rooms_with_many_extremities_txn(txn):
|
def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]:
|
||||||
where_clause = "1=1"
|
where_clause = "1=1"
|
||||||
if room_id_filter:
|
if room_id_filter:
|
||||||
where_clause = "room_id NOT IN (%s)" % (
|
where_clause = "room_id NOT IN (%s)" % (
|
||||||
|
@ -937,7 +957,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
"get_min_depth", self._get_min_depth_interaction, room_id
|
"get_min_depth", self._get_min_depth_interaction, room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_min_depth_interaction(self, txn, room_id):
|
def _get_min_depth_interaction(
|
||||||
|
self, txn: LoggingTransaction, room_id: str
|
||||||
|
) -> Optional[int]:
|
||||||
min_depth = self.db_pool.simple_select_one_onecol_txn(
|
min_depth = self.db_pool.simple_select_one_onecol_txn(
|
||||||
txn,
|
txn,
|
||||||
table="room_depth",
|
table="room_depth",
|
||||||
|
@ -966,22 +988,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
"""
|
"""
|
||||||
# We want to make the cache more effective, so we clamp to the last
|
# We want to make the cache more effective, so we clamp to the last
|
||||||
# change before the given ordering.
|
# change before the given ordering.
|
||||||
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
|
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined]
|
||||||
|
|
||||||
# We don't always have a full stream_to_exterm_id table, e.g. after
|
# We don't always have a full stream_to_exterm_id table, e.g. after
|
||||||
# the upgrade that introduced it, so we make sure we never ask for a
|
# the upgrade that introduced it, so we make sure we never ask for a
|
||||||
# stream_ordering from before a restart
|
# stream_ordering from before a restart
|
||||||
last_change = max(self._stream_order_on_start, last_change)
|
last_change = max(self._stream_order_on_start, last_change) # type: ignore[attr-defined]
|
||||||
|
|
||||||
# provided the last_change is recent enough, we now clamp the requested
|
# provided the last_change is recent enough, we now clamp the requested
|
||||||
# stream_ordering to it.
|
# stream_ordering to it.
|
||||||
if last_change > self.stream_ordering_month_ago:
|
if last_change > self.stream_ordering_month_ago: # type: ignore[attr-defined]
|
||||||
stream_ordering = min(last_change, stream_ordering)
|
stream_ordering = min(last_change, stream_ordering)
|
||||||
|
|
||||||
return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
|
return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
|
||||||
|
|
||||||
@cached(max_entries=5000, num_args=2)
|
@cached(max_entries=5000, num_args=2)
|
||||||
async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
|
async def _get_forward_extremeties_for_room(
|
||||||
|
self, room_id: str, stream_ordering: int
|
||||||
|
) -> List[str]:
|
||||||
"""For a given room_id and stream_ordering, return the forward
|
"""For a given room_id and stream_ordering, return the forward
|
||||||
extremeties of the room at that point in "time".
|
extremeties of the room at that point in "time".
|
||||||
|
|
||||||
|
@ -989,7 +1013,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
stream_orderings from that point.
|
stream_orderings from that point.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if stream_ordering <= self.stream_ordering_month_ago:
|
if stream_ordering <= self.stream_ordering_month_ago: # type: ignore[attr-defined]
|
||||||
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
|
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
|
@ -1002,7 +1026,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
WHERE room_id = ?
|
WHERE room_id = ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_forward_extremeties_for_room_txn(txn):
|
def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]:
|
||||||
txn.execute(sql, (stream_ordering, room_id))
|
txn.execute(sql, (stream_ordering, room_id))
|
||||||
return [event_id for event_id, in txn]
|
return [event_id for event_id, in txn]
|
||||||
|
|
||||||
|
@ -1104,8 +1128,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
]
|
]
|
||||||
|
|
||||||
async def get_backfill_events(
|
async def get_backfill_events(
|
||||||
self, room_id: str, seed_event_id_list: list, limit: int
|
self, room_id: str, seed_event_id_list: List[str], limit: int
|
||||||
):
|
) -> List[EventBase]:
|
||||||
"""Get a list of Events for a given topic that occurred before (and
|
"""Get a list of Events for a given topic that occurred before (and
|
||||||
including) the events in seed_event_id_list. Return a list of max size `limit`
|
including) the events in seed_event_id_list. Return a list of max size `limit`
|
||||||
|
|
||||||
|
@ -1123,10 +1147,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
)
|
)
|
||||||
events = await self.get_events_as_list(event_ids)
|
events = await self.get_events_as_list(event_ids)
|
||||||
return sorted(
|
return sorted(
|
||||||
events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
|
# type-ignore: mypy doesn't like negating the Optional[int] stream_ordering.
|
||||||
|
# But it's never None, because these events were previously persisted to the DB.
|
||||||
|
events,
|
||||||
|
key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering), # type: ignore[operator]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit):
|
def _get_backfill_events(
|
||||||
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
room_id: str,
|
||||||
|
seed_event_id_list: List[str],
|
||||||
|
limit: int,
|
||||||
|
) -> Set[str]:
|
||||||
"""
|
"""
|
||||||
We want to make sure that we do a breadth-first, "depth" ordered search.
|
We want to make sure that we do a breadth-first, "depth" ordered search.
|
||||||
We also handle navigating historical branches of history connected by
|
We also handle navigating historical branches of history connected by
|
||||||
|
@ -1139,7 +1172,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
limit,
|
limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
event_id_results = set()
|
event_id_results: Set[str] = set()
|
||||||
|
|
||||||
# In a PriorityQueue, the lowest valued entries are retrieved first.
|
# In a PriorityQueue, the lowest valued entries are retrieved first.
|
||||||
# We're using depth as the priority in the queue and tie-break based on
|
# We're using depth as the priority in the queue and tie-break based on
|
||||||
|
@ -1147,7 +1180,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
# highest and newest-in-time message. We add events to the queue with a
|
# highest and newest-in-time message. We add events to the queue with a
|
||||||
# negative depth so that we process the newest-in-time messages first
|
# negative depth so that we process the newest-in-time messages first
|
||||||
# going backwards in time. stream_ordering follows the same pattern.
|
# going backwards in time. stream_ordering follows the same pattern.
|
||||||
queue = PriorityQueue()
|
queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue()
|
||||||
|
|
||||||
for seed_event_id in seed_event_id_list:
|
for seed_event_id in seed_event_id_list:
|
||||||
event_lookup_result = self.db_pool.simple_select_one_txn(
|
event_lookup_result = self.db_pool.simple_select_one_txn(
|
||||||
|
@ -1253,7 +1286,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
return event_id_results
|
return event_id_results
|
||||||
|
|
||||||
async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
|
async def get_missing_events(
|
||||||
|
self,
|
||||||
|
room_id: str,
|
||||||
|
earliest_events: List[str],
|
||||||
|
latest_events: List[str],
|
||||||
|
limit: int,
|
||||||
|
) -> List[EventBase]:
|
||||||
ids = await self.db_pool.runInteraction(
|
ids = await self.db_pool.runInteraction(
|
||||||
"get_missing_events",
|
"get_missing_events",
|
||||||
self._get_missing_events,
|
self._get_missing_events,
|
||||||
|
@ -1264,11 +1303,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
)
|
)
|
||||||
return await self.get_events_as_list(ids)
|
return await self.get_events_as_list(ids)
|
||||||
|
|
||||||
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
|
def _get_missing_events(
|
||||||
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
room_id: str,
|
||||||
|
earliest_events: List[str],
|
||||||
|
latest_events: List[str],
|
||||||
|
limit: int,
|
||||||
|
) -> List[str]:
|
||||||
|
|
||||||
seen_events = set(earliest_events)
|
seen_events = set(earliest_events)
|
||||||
front = set(latest_events) - seen_events
|
front = set(latest_events) - seen_events
|
||||||
event_results = []
|
event_results: List[str] = []
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"SELECT prev_event_id FROM event_edges "
|
"SELECT prev_event_id FROM event_edges "
|
||||||
|
@ -1311,7 +1357,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
@wrap_as_background_process("delete_old_forward_extrem_cache")
|
@wrap_as_background_process("delete_old_forward_extrem_cache")
|
||||||
async def _delete_old_forward_extrem_cache(self) -> None:
|
async def _delete_old_forward_extrem_cache(self) -> None:
|
||||||
def _delete_old_forward_extrem_cache_txn(txn):
|
def _delete_old_forward_extrem_cache_txn(txn: LoggingTransaction) -> None:
|
||||||
# Delete entries older than a month, while making sure we don't delete
|
# Delete entries older than a month, while making sure we don't delete
|
||||||
# the only entries for a room.
|
# the only entries for a room.
|
||||||
sql = """
|
sql = """
|
||||||
|
@ -1324,7 +1370,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
) AND stream_ordering < ?
|
) AND stream_ordering < ?
|
||||||
"""
|
"""
|
||||||
txn.execute(
|
txn.execute(
|
||||||
sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
|
sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
|
@ -1382,7 +1428,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
"""
|
"""
|
||||||
if self.db_pool.engine.supports_returning:
|
if self.db_pool.engine.supports_returning:
|
||||||
|
|
||||||
def _remove_received_event_from_staging_txn(txn):
|
def _remove_received_event_from_staging_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Optional[int]:
|
||||||
sql = """
|
sql = """
|
||||||
DELETE FROM federation_inbound_events_staging
|
DELETE FROM federation_inbound_events_staging
|
||||||
WHERE origin = ? AND event_id = ?
|
WHERE origin = ? AND event_id = ?
|
||||||
|
@ -1390,21 +1438,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
"""
|
"""
|
||||||
|
|
||||||
txn.execute(sql, (origin, event_id))
|
txn.execute(sql, (origin, event_id))
|
||||||
return txn.fetchone()
|
row = cast(Optional[Tuple[int]], txn.fetchone())
|
||||||
|
|
||||||
row = await self.db_pool.runInteraction(
|
if row is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return row[0]
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
"remove_received_event_from_staging",
|
"remove_received_event_from_staging",
|
||||||
_remove_received_event_from_staging_txn,
|
_remove_received_event_from_staging_txn,
|
||||||
db_autocommit=True,
|
db_autocommit=True,
|
||||||
)
|
)
|
||||||
if row is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return row[0]
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def _remove_received_event_from_staging_txn(txn):
|
def _remove_received_event_from_staging_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Optional[int]:
|
||||||
received_ts = self.db_pool.simple_select_one_onecol_txn(
|
received_ts = self.db_pool.simple_select_one_onecol_txn(
|
||||||
txn,
|
txn,
|
||||||
table="federation_inbound_events_staging",
|
table="federation_inbound_events_staging",
|
||||||
|
@ -1437,7 +1488,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
) -> Optional[Tuple[str, str]]:
|
) -> Optional[Tuple[str, str]]:
|
||||||
"""Get the next event ID in the staging area for the given room."""
|
"""Get the next event ID in the staging area for the given room."""
|
||||||
|
|
||||||
def _get_next_staged_event_id_for_room_txn(txn):
|
def _get_next_staged_event_id_for_room_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Optional[Tuple[str, str]]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT origin, event_id
|
SELECT origin, event_id
|
||||||
FROM federation_inbound_events_staging
|
FROM federation_inbound_events_staging
|
||||||
|
@ -1448,7 +1501,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
|
|
||||||
txn.execute(sql, (room_id,))
|
txn.execute(sql, (room_id,))
|
||||||
|
|
||||||
return txn.fetchone()
|
return cast(Optional[Tuple[str, str]], txn.fetchone())
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn
|
"get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn
|
||||||
|
@ -1461,7 +1514,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
) -> Optional[Tuple[str, EventBase]]:
|
) -> Optional[Tuple[str, EventBase]]:
|
||||||
"""Get the next event in the staging area for the given room."""
|
"""Get the next event in the staging area for the given room."""
|
||||||
|
|
||||||
def _get_next_staged_event_for_room_txn(txn):
|
def _get_next_staged_event_for_room_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Optional[Tuple[str, str, str]]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT event_json, internal_metadata, origin
|
SELECT event_json, internal_metadata, origin
|
||||||
FROM federation_inbound_events_staging
|
FROM federation_inbound_events_staging
|
||||||
|
@ -1471,7 +1526,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (room_id,))
|
txn.execute(sql, (room_id,))
|
||||||
|
|
||||||
return txn.fetchone()
|
return cast(Optional[Tuple[str, str, str]], txn.fetchone())
|
||||||
|
|
||||||
row = await self.db_pool.runInteraction(
|
row = await self.db_pool.runInteraction(
|
||||||
"get_next_staged_event_for_room", _get_next_staged_event_for_room_txn
|
"get_next_staged_event_for_room", _get_next_staged_event_for_room_txn
|
||||||
|
@ -1599,18 +1654,20 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||||
)
|
)
|
||||||
|
|
||||||
@wrap_as_background_process("_get_stats_for_federation_staging")
|
@wrap_as_background_process("_get_stats_for_federation_staging")
|
||||||
async def _get_stats_for_federation_staging(self):
|
async def _get_stats_for_federation_staging(self) -> None:
|
||||||
"""Update the prometheus metrics for the inbound federation staging area."""
|
"""Update the prometheus metrics for the inbound federation staging area."""
|
||||||
|
|
||||||
def _get_stats_for_federation_staging_txn(txn):
|
def _get_stats_for_federation_staging_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[int, int]:
|
||||||
txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
|
txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
|
||||||
(count,) = txn.fetchone()
|
(count,) = cast(Tuple[int], txn.fetchone())
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT min(received_ts) FROM federation_inbound_events_staging"
|
"SELECT min(received_ts) FROM federation_inbound_events_staging"
|
||||||
)
|
)
|
||||||
|
|
||||||
(received_ts,) = txn.fetchone()
|
(received_ts,) = cast(Tuple[Optional[int]], txn.fetchone())
|
||||||
|
|
||||||
# If there is nothing in the staging area default it to 0.
|
# If there is nothing in the staging area default it to 0.
|
||||||
age = 0
|
age = 0
|
||||||
|
@ -1651,19 +1708,21 @@ class EventFederationStore(EventFederationWorkerStore):
|
||||||
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
|
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
|
||||||
)
|
)
|
||||||
|
|
||||||
async def clean_room_for_join(self, room_id):
|
async def clean_room_for_join(self, room_id: str) -> None:
|
||||||
return await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"clean_room_for_join", self._clean_room_for_join_txn, room_id
|
"clean_room_for_join", self._clean_room_for_join_txn, room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def _clean_room_for_join_txn(self, txn, room_id):
|
def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None:
|
||||||
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
|
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
|
||||||
|
|
||||||
txn.execute(query, (room_id,))
|
txn.execute(query, (room_id,))
|
||||||
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
|
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
|
||||||
|
|
||||||
async def _background_delete_non_state_event_auth(self, progress, batch_size):
|
async def _background_delete_non_state_event_auth(
|
||||||
def delete_event_auth(txn):
|
self, progress: JsonDict, batch_size: int
|
||||||
|
) -> int:
|
||||||
|
def delete_event_auth(txn: LoggingTransaction) -> bool:
|
||||||
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
|
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
|
||||||
max_stream_id = progress.get("max_stream_id_exclusive")
|
max_stream_id = progress.get("max_stream_id_exclusive")
|
||||||
|
|
||||||
|
|
|
@ -332,6 +332,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
most_recent_prev_event_depth,
|
most_recent_prev_event_depth,
|
||||||
) = self.get_success(self.store.get_max_depth_of(prev_event_ids))
|
) = self.get_success(self.store.get_max_depth_of(prev_event_ids))
|
||||||
# mapping from (type, state_key) -> state_event_id
|
# mapping from (type, state_key) -> state_event_id
|
||||||
|
assert most_recent_prev_event_id is not None
|
||||||
prev_state_map = self.get_success(
|
prev_state_map = self.get_success(
|
||||||
self.state_store.get_state_ids_for_event(most_recent_prev_event_id)
|
self.state_store.get_state_ids_for_event(most_recent_prev_event_id)
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue