0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 04:33:47 +01:00

Add some type hints to datastore. (#12477)

This commit is contained in:
Dirk Klimpel 2022-05-10 20:07:48 +02:00 committed by GitHub
parent 147f098fb4
commit 989fa33096
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 122 additions and 71 deletions

1
changelog.d/12477.misc Normal file
View file

@@ -0,0 +1 @@
Add some type hints to datastore.

View file

@@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import attr
from frozendict import frozendict
from typing_extensions import Literal
from twisted.internet.defer import Deferred
@@ -106,7 +107,7 @@ class EventContext:
incomplete state.
"""
rejected: Union[bool, str] = False
rejected: Union[Literal[False], str] = False
_state_group: Optional[int] = None
state_group_before_event: Optional[int] = None
prev_group: Optional[int] = None

View file

@@ -49,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.engines.postgres import PostgresEngine
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically
@@ -235,7 +235,9 @@ class PersistEventsStore:
"""
results: List[str] = []
def _get_events_which_are_prevs_txn(txn, batch):
def _get_events_which_are_prevs_txn(
txn: LoggingTransaction, batch: Collection[str]
) -> None:
sql = """
SELECT prev_event_id, internal_metadata
FROM event_edges
@@ -285,7 +287,9 @@ class PersistEventsStore:
# and their prev events.
existing_prevs = set()
def _get_prevs_before_rejected_txn(txn, batch):
def _get_prevs_before_rejected_txn(
txn: LoggingTransaction, batch: Collection[str]
) -> None:
to_recursively_check = batch
while to_recursively_check:
@@ -515,7 +519,7 @@ class PersistEventsStore:
@classmethod
def _add_chain_cover_index(
cls,
txn,
txn: LoggingTransaction,
db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
@@ -809,7 +813,7 @@ class PersistEventsStore:
@staticmethod
def _allocate_chain_ids(
txn,
txn: LoggingTransaction,
db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
@@ -943,7 +947,7 @@ class PersistEventsStore:
self,
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
):
) -> None:
"""Persist the mapping from transaction IDs to event IDs (if defined)."""
to_insert = []
@@ -997,7 +1001,7 @@ class PersistEventsStore:
txn: LoggingTransaction,
state_delta_by_room: Dict[str, DeltaState],
stream_id: int,
):
) -> None:
for room_id, delta_state in state_delta_by_room.items():
to_delete = delta_state.to_delete
to_insert = delta_state.to_insert
@@ -1155,7 +1159,7 @@ class PersistEventsStore:
txn, room_id, members_changed
)
def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
"""Update the room version in the database based off current state
events.
@@ -1189,7 +1193,7 @@ class PersistEventsStore:
txn: LoggingTransaction,
new_forward_extremities: Dict[str, Set[str]],
max_stream_order: int,
):
) -> None:
for room_id in new_forward_extremities.keys():
self.db_pool.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
@@ -1254,9 +1258,9 @@ class PersistEventsStore:
def _update_room_depths_txn(
self,
txn,
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
):
) -> None:
"""Update min_depth for each room
Args:
@@ -1385,7 +1389,7 @@ class PersistEventsStore:
# nothing to do here
return
def event_dict(event):
def event_dict(event: EventBase) -> JsonDict:
d = event.get_dict()
d.pop("redacted", None)
d.pop("redacted_because", None)
@@ -1476,18 +1480,20 @@ class PersistEventsStore:
),
)
def _store_rejected_events_txn(self, txn, events_and_contexts):
def _store_rejected_events_txn(
self,
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
) -> List[Tuple[EventBase, EventContext]]:
"""Add rows to the 'rejections' table for received events which were
rejected
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
events_and_contexts (list[(EventBase, EventContext)]): events
we are persisting
txn: db connection
events_and_contexts: events we are persisting
Returns:
list[(EventBase, EventContext)] new list, without the rejected
events.
new list, without the rejected events.
"""
# Remove the rejected events from the list now that we've added them
# to the events table and the events_json table.
@@ -1508,7 +1514,7 @@ class PersistEventsStore:
events_and_contexts: List[Tuple[EventBase, EventContext]],
all_events_and_contexts: List[Tuple[EventBase, EventContext]],
inhibit_local_membership_updates: bool = False,
):
) -> None:
"""Update all the miscellaneous tables for new events
Args:
@@ -1602,7 +1608,11 @@ class PersistEventsStore:
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
def _add_to_cache(self, txn, events_and_contexts):
def _add_to_cache(
self,
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
) -> None:
to_prefill = []
rows = []
@@ -1633,7 +1643,7 @@ class PersistEventsStore:
if not row["rejects"] and not row["redacts"]:
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
def prefill():
def prefill() -> None:
for cache_entry in to_prefill:
self.store._get_event_cache.set(
(cache_entry.event.event_id,), cache_entry
@@ -1663,19 +1673,24 @@ class PersistEventsStore:
)
def insert_labels_for_event_txn(
self, txn, event_id, labels, room_id, topological_ordering
):
self,
txn: LoggingTransaction,
event_id: str,
labels: List[str],
room_id: str,
topological_ordering: int,
) -> None:
"""Store the mapping between an event's ID and its labels, with one row per
(event_id, label) tuple.
Args:
txn (LoggingTransaction): The transaction to execute.
event_id (str): The event's ID.
labels (list[str]): A list of text labels.
room_id (str): The ID of the room the event was sent to.
topological_ordering (int): The position of the event in the room's topology.
txn: The transaction to execute.
event_id: The event's ID.
labels: A list of text labels.
room_id: The ID of the room the event was sent to.
topological_ordering: The position of the event in the room's topology.
"""
return self.db_pool.simple_insert_many_txn(
self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_labels",
keys=("event_id", "label", "room_id", "topological_ordering"),
@@ -1684,25 +1699,32 @@ class PersistEventsStore:
],
)
def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
def _insert_event_expiry_txn(
self, txn: LoggingTransaction, event_id: str, expiry_ts: int
) -> None:
"""Save the expiry timestamp associated with a given event ID.
Args:
txn (LoggingTransaction): The database transaction to use.
event_id (str): The event ID the expiry timestamp is associated with.
expiry_ts (int): The timestamp at which to expire (delete) the event.
txn: The database transaction to use.
event_id: The event ID the expiry timestamp is associated with.
expiry_ts: The timestamp at which to expire (delete) the event.
"""
return self.db_pool.simple_insert_txn(
self.db_pool.simple_insert_txn(
txn=txn,
table="event_expiry",
values={"event_id": event_id, "expiry_ts": expiry_ts},
)
def _store_room_members_txn(
self, txn, events, *, inhibit_local_membership_updates: bool = False
):
self,
txn: LoggingTransaction,
events: List[EventBase],
*,
inhibit_local_membership_updates: bool = False,
) -> None:
"""
Store a room member in the database.
Args:
txn: The transaction to use.
events: List of events to store.
@@ -1742,6 +1764,7 @@ class PersistEventsStore:
)
for event in events:
assert event.internal_metadata.stream_ordering is not None
txn.call_after(
self.store._membership_stream_cache.entity_has_changed,
event.state_key,
@@ -1838,7 +1861,9 @@ class PersistEventsStore:
(parent_id, event.sender),
)
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
def _handle_insertion_event(
self, txn: LoggingTransaction, event: EventBase
) -> None:
"""Handles keeping track of insertion events and edges/connections.
Part of MSC2716.
@@ -1899,7 +1924,7 @@ class PersistEventsStore:
},
)
def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None:
"""Handles inserting the batch edges/connections between the batch event
and an insertion event. Part of MSC2716.
@@ -1999,25 +2024,29 @@ class PersistEventsStore:
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
if isinstance(event.content.get("topic"), str):
self.store_event_search_txn(
txn, event, "content.topic", event.content["topic"]
)
def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
if isinstance(event.content.get("name"), str):
self.store_event_search_txn(
txn, event, "content.name", event.content["name"]
)
def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
def _store_room_message_txn(
self, txn: LoggingTransaction, event: EventBase
) -> None:
if isinstance(event.content.get("body"), str):
self.store_event_search_txn(
txn, event, "content.body", event.content["body"]
)
def _store_retention_policy_for_room_txn(self, txn, event):
def _store_retention_policy_for_room_txn(
self, txn: LoggingTransaction, event: EventBase
) -> None:
if not event.is_state():
logger.debug("Ignoring non-state m.room.retention event")
return
@@ -2077,8 +2106,11 @@ class PersistEventsStore:
)
def _set_push_actions_for_event_and_users_txn(
self, txn, events_and_contexts, all_events_and_contexts
):
self,
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
all_events_and_contexts: List[Tuple[EventBase, EventContext]],
) -> None:
"""Handles moving push actions from staging table to main
event_push_actions table for all events in `events_and_contexts`.
@@ -2086,12 +2118,10 @@ class PersistEventsStore:
from the push action staging area.
Args:
events_and_contexts (list[(EventBase, EventContext)]): events
we are persisting
all_events_and_contexts (list[(EventBase, EventContext)]): all
events that we were going to persist. This includes events
we've already persisted, etc, that wouldn't appear in
events_and_context.
events_and_contexts: events we are persisting
all_events_and_contexts: all events that we were going to persist.
This includes events we've already persisted, etc, that wouldn't
appear in events_and_context.
"""
# Only non outlier events will have push actions associated with them,
@@ -2160,7 +2190,9 @@ class PersistEventsStore:
),
)
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
def _remove_push_actions_for_event_id_txn(
self, txn: LoggingTransaction, room_id: str, event_id: str
) -> None:
# Sad that we have to blow away the cache for the whole room here
txn.call_after(
self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
@@ -2171,7 +2203,9 @@ class PersistEventsStore:
(room_id, event_id),
)
def _store_rejections_txn(self, txn, event_id, reason):
def _store_rejections_txn(
self, txn: LoggingTransaction, event_id: str, reason: str
) -> None:
self.db_pool.simple_insert_txn(
txn,
table="rejections",
@@ -2183,8 +2217,10 @@ class PersistEventsStore:
)
def _store_event_state_mappings_txn(
self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
):
self,
txn: LoggingTransaction,
events_and_contexts: Collection[Tuple[EventBase, EventContext]],
) -> None:
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
@@ -2241,7 +2277,9 @@ class PersistEventsStore:
state_group_id,
)
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
def _update_min_depth_for_room_txn(
self, txn: LoggingTransaction, room_id: str, depth: int
) -> None:
min_depth = self.store._get_min_depth_interaction(txn, room_id)
if min_depth is not None and depth >= min_depth:
@@ -2254,7 +2292,9 @@ class PersistEventsStore:
values={"min_depth": depth},
)
def _handle_mult_prev_events(self, txn, events):
def _handle_mult_prev_events(
self, txn: LoggingTransaction, events: List[EventBase]
) -> None:
"""
For the given event, update the event edges table and forward and
backward extremities tables.
@@ -2272,7 +2312,9 @@ class PersistEventsStore:
self._update_backward_extremeties(txn, events)
def _update_backward_extremeties(self, txn, events):
def _update_backward_extremeties(
self, txn: LoggingTransaction, events: List[EventBase]
) -> None:
"""Updates the event_backward_extremities tables based on the new/updated
events being persisted.

View file

@@ -14,7 +14,7 @@
import logging
import re
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
import attr
@@ -27,7 +27,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
if TYPE_CHECKING:
@@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
)
async def _background_reindex_search(self, progress, batch_size):
async def _background_reindex_search(
self, progress: JsonDict, batch_size: int
) -> int:
# we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
def reindex_search_txn(txn):
def reindex_search_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id, room_id, type, json, "
" origin_server_ts FROM events"
@@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
return result
async def _background_reindex_gin_search(self, progress, batch_size):
async def _background_reindex_gin_search(
self, progress: JsonDict, batch_size: int
) -> int:
"""This handles old synapses which used GIST indexes, if any;
converting them back to be GIN as per the actual schema.
"""
def create_index(conn):
def create_index(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
# we have to set autocommit, because postgres refuses to
@@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
)
return 1
async def _background_reindex_search_order(self, progress, batch_size):
async def _background_reindex_search_order(
self, progress: JsonDict, batch_size: int
) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
if not have_added_index:
def create_index(conn):
def create_index(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
conn.set_session(autocommit=True)
c = conn.cursor()
@@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
pg,
)
def reindex_search_txn(txn):
def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
sql = (
"UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
" origin_server_ts = e.origin_server_ts"
@@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore):
else:
raise Exception("Unrecognized database engine")
args.append(limit)
# mypy expects to append only a `str`, not an `int`
args.append(limit) # type: ignore[arg-type]
results = await self.db_pool.execute(
"search_rooms", self.db_pool.cursor_to_dict, sql, *args
@@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore):
A set of strings.
"""
def f(txn):
def f(txn: LoggingTransaction) -> Set[str]:
highlight_words = set()
for event in events:
# As a hack we simply join values of all possible keys. This is
@@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore):
return await self.db_pool.runInteraction("_find_highlights", f)
def _to_postgres_options(options_dict):
def _to_postgres_options(options_dict: JsonDict) -> str:
return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
def _parse_query(database_engine, search_term):
def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something