0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-02 10:48:56 +02:00

Add type hints to synapse/storage/databases/main/events_worker.py (#11411)

Also refactor the stream ID trackers/generators a bit and try to
document them better.
This commit is contained in:
Sean Quah 2021-11-26 18:41:31 +00:00 committed by GitHub
parent 1d8b80b334
commit ffd858aa68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 255 additions and 171 deletions

1
changelog.d/11411.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to storage classes.

View file

@ -33,7 +33,6 @@ exclude = (?x)
|synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/event_push_actions.py
|synapse/storage/databases/main/events_bg_updates.py
|synapse/storage/databases/main/events_worker.py
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
@ -184,6 +183,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.directory]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.events_worker]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.room_batch]
disallow_untyped_defs = True

View file

@ -14,10 +14,18 @@
from typing import List, Optional, Tuple
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.util.id_generators import _load_current_id
from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id
class SlavedIdTracker:
class SlavedIdTracker(AbstractStreamIdTracker):
"""Tracks the "current" stream ID of a stream with a single writer.
See `AbstractStreamIdTracker` for more details.
Note that this class does not work correctly when there are multiple
writers.
"""
def __init__(
self,
db_conn: LoggingDatabaseConnection,
@ -36,17 +44,7 @@ class SlavedIdTracker:
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self) -> int:
"""
Returns:
int
"""
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to
`get_current_token`.
"""
return self.get_current_token()

View file

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@ -25,9 +24,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
# We assert this for the benefit of mypy
assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:

View file

@ -14,7 +14,7 @@
# limitations under the License.
import heapq
from collections.abc import Iterable
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Optional, Tuple, Type
import attr
@ -157,7 +157,7 @@ class EventsStream(Stream):
# now we fetch up to that many rows from the events table
event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
event_rows = await self._store.get_all_new_forward_event_rows(
instance_name, from_token, current_token, target_row_count
)
@ -191,7 +191,7 @@ class EventsStream(Stream):
# finally, fetch the ex-outliers rows. We assume there are few enough of these
# not to bother with the limit.
ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
instance_name, from_token, upper_limit
)

View file

@ -764,7 +764,7 @@ class StateResolutionStore:
store: "DataStore"
def get_events(
self, event_ids: Iterable[str], allow_rejected: bool = False
self, event_ids: Collection[str], allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database

View file

@ -17,6 +17,7 @@ import logging
from typing import (
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
@ -44,7 +45,7 @@ async def resolve_events_with_store(
room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
) -> StateMap[str]:
"""
Args:

View file

@ -21,7 +21,7 @@ from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection
from synapse.types import StreamToken, get_domain_from_id
from synapse.types import get_domain_from_id
from synapse.util import json_decoder
if TYPE_CHECKING:
@ -48,7 +48,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self,
stream_name: str,
instance_name: str,
token: StreamToken,
token: int,
rows: Iterable[Any],
) -> None:
pass

View file

@ -15,7 +15,7 @@
# limitations under the License.
import itertools
import logging
from collections import OrderedDict, namedtuple
from collections import OrderedDict
from typing import (
TYPE_CHECKING,
Any,
@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
@ -64,9 +65,6 @@ event_counter = Counter(
)
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
@attr.s(slots=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@ -108,16 +106,21 @@ class PersistEventsStore:
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
# Ideally we'd move these ID gens here, unfortunately some other ID
# generators are chained off them so doing so is a bit of a PITA.
self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
# This should only exist on instances that are configured to write
assert (
hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
# Since we have been configured to write, we ought to have id generators,
# rather than id trackers.
assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
# Ideally we'd move these ID gens here, unfortunately some other ID
# generators are chained off them so doing so is a bit of a PITA.
self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
@ -1553,11 +1556,13 @@ class PersistEventsStore:
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
def prefill():
for cache_entry in to_prefill:
self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
self.store._get_event_cache.set(
(cache_entry.event.event_id,), cache_entry
)
txn.call_after(prefill)

View file

@ -15,14 +15,18 @@
import logging
import threading
from typing import (
TYPE_CHECKING,
Any,
Collection,
Container,
Dict,
Iterable,
List,
NoReturn,
Optional,
Set,
Tuple,
cast,
overload,
)
@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersion,
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
@ -69,6 +82,9 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
@attr.s(slots=True, auto_attribs=True)
class _EventCacheEntry:
class EventCacheEntry:
event: EventBase
redacted_event: Optional[EventBase]
@ -129,7 +145,7 @@ class _EventRow:
json: str
internal_metadata: str
format_version: Optional[int]
room_version_id: Optional[int]
room_version_id: Optional[str]
rejected_reason: Optional[str]
redactions: List[str]
outlier: bool
@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
# options controlling this.
USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._stream_id_gen: AbstractStreamIdTracker
self._backfill_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres than we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
5 * 60 * 1000,
)
self._get_event_cache = LruCache(
self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
)
@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
# ID to cache entry. Note that the returned dict may not have the
# requested event in it if the event isn't in the DB.
self._current_event_fetches: Dict[
str, ObservableDeferred[Dict[str, _EventCacheEntry]]
str, ObservableDeferred[Dict[str, EventCacheEntry]]
] = {}
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_list: List[
Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
] = []
self._event_fetch_ongoing = 0
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
def get_chain_id_txn(txn):
def get_chain_id_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
return txn.fetchone()[0]
return cast(Tuple[int], txn.fetchone())[0]
self.event_chain_id_gen = build_sequence_generator(
db_conn,
@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id",
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
allow_none: Literal[False] = False,
check_room_id: Optional[str] = None,
get_prev_content: bool = ...,
allow_rejected: bool = ...,
allow_none: Literal[False] = ...,
check_room_id: Optional[str] = ...,
) -> EventBase:
...
@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
allow_none: Literal[True] = False,
check_room_id: Optional[str] = None,
get_prev_content: bool = ...,
allow_rejected: bool = ...,
allow_none: Literal[True] = ...,
check_room_id: Optional[str] = ...,
) -> Optional[EventBase]:
...
@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_events(
self,
event_ids: Iterable[str],
event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
) -> Dict[str, _EventCacheEntry]:
) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
# same dict into itself N times).
already_fetching_ids: Set[str] = set()
already_fetching_deferreds: Set[
ObservableDeferred[Dict[str, _EventCacheEntry]]
ObservableDeferred[Dict[str, EventCacheEntry]]
] = set()
for event_id in missing_events_ids:
@ -601,7 +632,7 @@ class EventsWorkerStore(SQLBaseStore):
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
Dict[str, _EventCacheEntry]
Dict[str, EventCacheEntry]
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred
@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
def _invalidate_get_event_cache(self, event_id):
def _invalidate_get_event_cache(self, event_id: str) -> None:
self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
) -> Dict[str, _EventCacheEntry]:
) -> Dict[str, EventCacheEntry]:
"""Fetch events from the caches.
May return rejected events.
@ -820,7 +851,7 @@ class EventsWorkerStore(SQLBaseStore):
for _, deferred in event_fetches_to_fail:
deferred.errback(exc)
def _fetch_loop(self, conn: Connection) -> None:
def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
@ -850,7 +881,9 @@ class EventsWorkerStore(SQLBaseStore):
self._fetch_event_list(conn, event_list)
def _fetch_event_list(
self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
self,
conn: LoggingDatabaseConnection,
event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
) -> None:
"""Handle a load of requests from the _event_fetch_list queue
@ -877,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
)
# We only want to resolve deferreds from the main thread
def fire():
def fire() -> None:
for _, d in event_list:
d.callback(row_dict)
@ -887,16 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
def fire(evs, exc):
for _, d in evs:
def fire_errback(exc: Exception) -> None:
for _, d in event_list:
d.errback(exc)
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
self.hs.get_reactor().callFromThread(fire_errback, e)
async def _get_events_from_db(
self, event_ids: Iterable[str]
) -> Dict[str, _EventCacheEntry]:
self, event_ids: Collection[str]
) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the database.
May return rejected events.
@ -912,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
map from event id to result. May return extra events which
weren't asked for.
"""
fetched_events = {}
fetched_event_ids: Set[str] = set()
fetched_events: Dict[str, _EventRow] = {}
events_to_fetch = event_ids
while events_to_fetch:
row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids = set()
redaction_ids: Set[str] = set()
for event_id in events_to_fetch:
row = row_map.get(event_id)
fetched_events[event_id] = row
fetched_event_ids.add(event_id)
if row:
fetched_events[event_id] = row
redaction_ids.update(row.redactions)
events_to_fetch = redaction_ids.difference(fetched_events.keys())
events_to_fetch = redaction_ids.difference(fetched_event_ids)
if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch)
# build a map from event_id to EventBase
event_map = {}
event_map: Dict[str, EventBase] = {}
for event_id, row in fetched_events.items():
if not row:
continue
assert row.event_id == event_id
rejected_reason = row.rejected_reason
@ -962,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row.room_version_id
room_version: Optional[RoomVersion]
if not room_version_id:
# this should only happen for out-of-band membership events which
# arrived before #6983 landed. For all other events, we should have
@ -1032,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
# finally, we can decide whether each one needs redacting, and build
# the cache entries.
result_map = {}
result_map: Dict[str, EventCacheEntry] = {}
for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
cache_entry = _EventCacheEntry(
cache_entry = EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)
@ -1048,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@ -1061,7 +1095,7 @@ class EventsWorkerStore(SQLBaseStore):
that weren't requested.
"""
events_d = defer.Deferred()
events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append((events, events_d))
self._event_fetch_lock.notify()
@ -1216,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
async def have_events_in_timeline(self, event_ids):
async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
@ -1245,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
event_ids: events we are looking for
Returns:
set[str]: The events we have already seen.
The set of events we have already seen.
"""
res = await self._have_seen_events_dict(
(room_id, event_id) for event_id in event_ids
@ -1268,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
}
results = {x: True for x in cache_results}
def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
def have_seen_events_txn(
txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
) -> None:
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
@ -1294,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
async def have_seen_event(self, room_id: str, event_id: str):
async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
# this only exists for the benefit of the @cachedList descriptor on
# _have_seen_events_dict
raise NotImplementedError()
def _get_current_state_event_counts_txn(self, txn, room_id):
def _get_current_state_event_counts_txn(
self, txn: LoggingTransaction, room_id: str
) -> int:
"""
See get_current_state_event_counts.
"""
@ -1324,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
async def get_room_complexity(self, room_id):
async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@ -1332,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
more resources.
Args:
room_id (str)
room_id: The room ID to query.
Returns:
dict[str:int] of complexity version to complexity.
dict[str:float] of complexity version to complexity.
"""
state_events = await self.get_current_state_event_counts(room_id)
@ -1345,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
return {"v1": complexity_v1}
def get_current_events_token(self):
def get_current_events_token(self) -> int:
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> List[Tuple]:
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns new events, for the Events replication stream
Args:
@ -1365,7 +1403,9 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
def get_all_new_forward_event_rows(txn):
def get_all_new_forward_event_rows(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
@ -1381,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, instance_name, limit))
return txn.fetchall()
return cast(
List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
)
return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
@ -1389,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_ex_outlier_stream_rows(
self, instance_name: str, last_id: int, current_id: int
) -> List[Tuple]:
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns de-outliered events, for the Events replication stream
Args:
@ -1402,7 +1444,9 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
def get_ex_outlier_stream_rows_txn(txn):
def get_ex_outlier_stream_rows_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
@ -1420,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (last_id, current_id, instance_name))
return txn.fetchall()
return cast(
List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
)
return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@ -1428,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_new_backfill_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, list]], int, bool]:
) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
"""Get updates for backfill replication stream, including all new
backfilled events and events that have gone from being outliers to not.
@ -1456,7 +1502,9 @@ class EventsWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
def get_all_new_backfill_event_rows(txn):
def get_all_new_backfill_event_rows(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
@ -1470,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, instance_name, limit))
new_event_updates = [(row[0], row[1:]) for row in txn]
new_event_updates: List[
Tuple[int, Tuple[str, str, str, str, str, str]]
] = []
row: Tuple[int, str, str, str, str, str, str]
# Type safety: iterating over `txn` yields `Tuple`, i.e.
# `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
# variadic tuple to a fixed length tuple and flags it up as an error.
for row in txn: # type: ignore[assignment]
new_event_updates.append((row[0], row[1:]))
limited = False
if len(new_event_updates) == limit:
@ -1493,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound, instance_name))
new_event_updates.extend((row[0], row[1:]) for row in txn)
# Type safety: iterating over `txn` yields `Tuple`, i.e.
# `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
# variadic tuple to a fixed length tuple and flags it up as an error.
for row in txn: # type: ignore[assignment]
new_event_updates.append((row[0], row[1:]))
if len(new_event_updates) >= limit:
upper_bound = new_event_updates[-1][0]
@ -1507,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_updated_current_state_deltas(
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple], int, bool]:
) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
"""Fetch updates from current_state_delta_stream
Args:
@ -1527,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
* `limited` is whether there are more updates to fetch.
"""
def get_all_updated_current_state_deltas_txn(txn):
def get_all_updated_current_state_deltas_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
@ -1536,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
return txn.fetchall()
return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
def get_deltas_for_stream_id_txn(txn, stream_id):
def get_deltas_for_stream_id_txn(
txn: LoggingTransaction, stream_id: int
) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE stream_id = ?
"""
txn.execute(sql, [stream_id])
return txn.fetchall()
return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
rows: List[Tuple] = await self.db_pool.runInteraction(
rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
@ -1579,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
async def is_event_after(self, event_id1, event_id2):
async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
"""Returns True if event_id1 is after event_id2 in the stream"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cached(max_entries=5000)
async def get_event_ordering(self, event_id):
async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
@ -1609,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
None otherwise.
"""
def get_next_event_to_expire_txn(txn):
def get_next_event_to_expire_txn(
txn: LoggingTransaction,
) -> Optional[Tuple[str, int]]:
txn.execute(
"""
SELECT event_id, expiry_ts FROM event_expiry
@ -1617,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
)
return txn.fetchone()
return cast(Optional[Tuple[str, int]], txn.fetchone())
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@ -1681,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
return mapping
@wrap_as_background_process("_cleanup_old_transaction_ids")
async def _cleanup_old_transaction_ids(self):
async def _cleanup_old_transaction_ids(self) -> None:
"""Cleans out transaction id mappings older than 24hrs."""
def _cleanup_old_transaction_ids_txn(txn):
def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?

View file

@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
StreamIdGenerator,
)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -82,9 +85,9 @@ class PushRulesWorkerStore(
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
self._push_rules_stream_id_gen: Union[
StreamIdGenerator, SlavedIdTracker
] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn, "push_rules_stream", "stream_id"
)
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"

View file

@ -89,31 +89,77 @@ def _load_current_id(
return (max if step > 0 else min)(current_id, step)
class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
@abc.abstractmethod
def get_next(self) -> AsyncContextManager[int]:
raise NotImplementedError()
class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
"""Tracks the "current" stream ID of a stream that may have multiple writers.
Stream IDs are monotonically increasing or decreasing integers representing write
transactions. The "current" stream ID is the stream ID such that all transactions
with equal or smaller stream IDs have completed. Since transactions may complete out
of order, this is not the same as the stream ID of the last completed transaction.
Completed transactions include both committed transactions and transactions that
have been rolled back.
"""
@abc.abstractmethod
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
def advance(self, instance_name: str, new_id: int) -> None:
"""Advance the position of the named writer to the given ID, if greater
than existing entry.
"""
raise NotImplementedError()
@abc.abstractmethod
def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
Returns:
The maximum stream id.
"""
raise NotImplementedError()
@abc.abstractmethod
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to `get_current_token`.
"""
raise NotImplementedError()
class AbstractStreamIdGenerator(AbstractStreamIdTracker):
"""Generates stream IDs for a stream that may have multiple writers.
Each stream ID represents a write transaction, whose completion is tracked
so that the "current" stream ID of the stream can be determined.
See `AbstractStreamIdTracker` for more details.
"""
@abc.abstractmethod
def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
raise NotImplementedError()
@abc.abstractmethod
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
"""
Usage:
async with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
raise NotImplementedError()
class StreamIdGenerator(AbstractStreamIdGenerator):
"""Used to generate new stream ids when persisting events while keeping
track of which transactions have been completed.
"""Generates and tracks stream IDs for a stream with a single writer.
This allows us to get the "current" stream id, i.e. the stream id such that
all ids less than or equal to it have completed. This handles the fact that
persistence of events can complete out of order.
This class must only be used when the current Synapse process is the sole
writer for a stream.
Args:
db_conn(connection): A database connection to use to fetch the
@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
def advance(self, instance_name: str, new_id: int) -> None:
# `StreamIdGenerator` should only be used when there is a single writer,
# so replication should never happen.
raise Exception("Replication is not supported by StreamIdGenerator")
def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
self._current += self._step
next_id = self._current
@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
"""
Usage:
async with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
next_ids = range(
self._current + self._step,
@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
Returns:
The maximum stream id.
"""
with self._lock:
if self._unfinished_ids:
return next(iter(self._unfinished_ids)) - self._step
@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to
`get_current_token`.
"""
return self.get_current_token()
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"""An ID generator that tracks a stream that can have multiple writers.
"""Generates and tracks stream IDs for a stream with multiple writers.
Uses a Postgres sequence to coordinate ID assignment, but positions of other
writers will only get updated when `advance` is called (by replication).
@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return stream_ids
def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
"""
Usage:
async with stream_id_gen.get_next_mult(5) as stream_ids:
# ... persist events ...
"""
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._add_persisted_position(next_id)
def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer."""
# If we don't have an entry for the given instance name, we assume it's a
# new writer.
#
@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
}
def advance(self, instance_name: str, new_id: int) -> None:
"""Advance the position of the named writer to the given ID, if greater
than existing entry.
"""
new_id *= self._return_factor
with self._lock:

View file

@ -17,6 +17,7 @@ from unittest.mock import patch
from synapse.api.room_versions import RoomVersion
from synapse.rest import admin
from synapse.rest.client import login, room, sync
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
@ -193,7 +194,10 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
#
# Worker2's event stream position will not advance until we call
# __aexit__ again.
actx = worker_hs2.get_datastore()._stream_id_gen.get_next()
worker_store2 = worker_hs2.get_datastore()
assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator)
actx = worker_store2._stream_id_gen.get_next()
self.get_success(actx.__aenter__())
response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token)