Add some type hints to datastore. (#12255)

This commit is contained in:
Dirk Klimpel 2022-03-28 20:11:14 +02:00 committed by GitHub
parent 4ba55a620f
commit ac95167d2f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 61 additions and 42 deletions

1
changelog.d/12255.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints for storage.

View file

@ -38,14 +38,11 @@ exclude = (?x)
|synapse/_scripts/update_synapse_database.py |synapse/_scripts/update_synapse_database.py
|synapse/storage/databases/__init__.py |synapse/storage/databases/__init__.py
|synapse/storage/databases/main/__init__.py
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py |synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py |synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/roommember.py |synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py |synapse/storage/databases/main/state.py
|synapse/storage/schema/ |synapse/storage/schema/

View file

@ -156,7 +156,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
hs: "HomeServer", hs: "HomeServer",
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name: str = hs.hostname
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media """Get the metadata for a local piece of media

View file

@ -24,10 +24,9 @@ from typing import (
Optional, Optional,
Set, Set,
Tuple, Tuple,
cast,
) )
from twisted.internet import defer
from synapse.api.constants import ReceiptTypes from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream from synapse.replication.tcp.streams import ReceiptsStream
@ -38,7 +37,11 @@ from synapse.storage.database import (
LoggingTransaction, LoggingTransaction,
) )
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -58,6 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
hs: "HomeServer", hs: "HomeServer",
): ):
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
self._receipts_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine): if isinstance(database.engine, PostgresEngine):
self._can_write_to_receipts = ( self._can_write_to_receipts = (
@ -161,7 +165,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
" AND user_id = ?" " AND user_id = ?"
) )
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
return txn.fetchall() return cast(List[Tuple[str, str, int, int]], txn.fetchall())
rows = await self.db_pool.runInteraction( rows = await self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f "get_receipts_for_user_with_orderings", f
@ -257,7 +261,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
if not rows: if not rows:
return [] return []
content = {} content: JsonDict = {}
for row in rows: for row in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"] row["user_id"]
@ -305,7 +309,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"_get_linearized_receipts_for_rooms", f "_get_linearized_receipts_for_rooms", f
) )
results = {} results: JsonDict = {}
for row in txn_results: for row in txn_results:
# We want a single event per room, since we want to batch the # We want a single event per room, since we want to batch the
# receipts by room, event and type. # receipts by room, event and type.
@ -370,7 +374,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"get_linearized_receipts_for_all_rooms", f "get_linearized_receipts_for_all_rooms", f
) )
results = {} results: JsonDict = {}
for row in txn_results: for row in txn_results:
# We want a single event per room, since we want to batch the # We want a single event per room, since we want to batch the
# receipts by room, event and type. # receipts by room, event and type.
@ -399,7 +403,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
""" """
if last_id == current_id: if last_id == current_id:
return defer.succeed([]) return []
def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]: def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
sql = """ sql = """
@ -453,7 +457,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
""" """
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn] updates = cast(
List[Tuple[int, list]],
[(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
)
limited = False limited = False
upper_bound = current_id upper_bound = current_id
@ -496,7 +503,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type)) self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, instance_name, token, rows): def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == ReceiptsStream.NAME: if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token) self._receipts_id_gen.advance(instance_name, token)
for row in rows: for row in rows:
@ -584,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
) )
if receipt_type == ReceiptTypes.READ and stream_ordering is not None: if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
self._remove_old_push_actions_before_txn( self._remove_old_push_actions_before_txn( # type: ignore[attr-defined]
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
) )
@ -637,7 +650,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"insert_receipt_conv", graph_to_linear "insert_receipt_conv", graph_to_linear
) )
async with self._receipts_id_gen.get_next() as stream_id: async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction( event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt", "insert_linearized_receipt",
self.insert_linearized_receipt_txn, self.insert_linearized_receipt_txn,

View file

@ -22,6 +22,7 @@ import attr
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
@ -123,7 +124,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.config = hs.config self.config: HomeServerConfig = hs.config
# Note: we don't check this sequence for consistency as we'd have to # Note: we don't check this sequence for consistency as we'd have to
# call `find_max_generated_user_id_localpart` each time, which is # call `find_max_generated_user_id_localpart` each time, which is

View file

@ -34,6 +34,7 @@ import attr
from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import ( from synapse.storage.database import (
@ -98,7 +99,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.config = hs.config self.config: HomeServerConfig = hs.config
async def store_room( async def store_room(
self, self,

View file

@ -14,7 +14,7 @@
import logging import logging
import re import re
from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
import attr import attr
@ -74,7 +74,7 @@ class SearchWorkerStore(SQLBaseStore):
" VALUES (?,?,?,to_tsvector('english', ?),?,?)" " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
) )
args = ( args1 = (
( (
entry.event_id, entry.event_id,
entry.room_id, entry.room_id,
@ -86,14 +86,14 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries for entry in entries
) )
txn.execute_batch(sql, args) txn.execute_batch(sql, args1)
elif isinstance(self.database_engine, Sqlite3Engine): elif isinstance(self.database_engine, Sqlite3Engine):
sql = ( sql = (
"INSERT INTO event_search (event_id, room_id, key, value)" "INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)" " VALUES (?,?,?,?)"
) )
args = ( args2 = (
( (
entry.event_id, entry.event_id,
entry.room_id, entry.room_id,
@ -102,7 +102,7 @@ class SearchWorkerStore(SQLBaseStore):
) )
for entry in entries for entry in entries
) )
txn.execute_batch(sql, args) txn.execute_batch(sql, args2)
else: else:
# This should be unreachable. # This should be unreachable.
@ -427,7 +427,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = _parse_query(self.database_engine, search_term) search_query = _parse_query(self.database_engine, search_term)
args = [] args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms. # Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless. # We filter the results below regardless.
@ -496,7 +496,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak) # search results (which is a data leak)
events = await self.get_events_as_list( events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results], [r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK, redact_behaviour=EventRedactBehaviour.BLOCK,
) )
@ -530,7 +530,7 @@ class SearchStore(SearchBackgroundUpdateStore):
room_ids: Collection[str], room_ids: Collection[str],
search_term: str, search_term: str,
keys: Iterable[str], keys: Iterable[str],
limit, limit: int,
pagination_token: Optional[str] = None, pagination_token: Optional[str] = None,
) -> JsonDict: ) -> JsonDict:
"""Performs a full text search over events with given keys. """Performs a full text search over events with given keys.
@ -549,7 +549,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = _parse_query(self.database_engine, search_term) search_query = _parse_query(self.database_engine, search_term)
args = [] args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms. # Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless. # We filter the results below regardless.
@ -573,9 +573,9 @@ class SearchStore(SearchBackgroundUpdateStore):
if pagination_token: if pagination_token:
try: try:
origin_server_ts, stream = pagination_token.split(",") origin_server_ts_str, stream_str = pagination_token.split(",")
origin_server_ts = int(origin_server_ts) origin_server_ts = int(origin_server_ts_str)
stream = int(stream) stream = int(stream_str)
except Exception: except Exception:
raise SynapseError(400, "Invalid pagination token") raise SynapseError(400, "Invalid pagination token")
@ -654,7 +654,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak) # search results (which is a data leak)
events = await self.get_events_as_list( events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results], [r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK, redact_behaviour=EventRedactBehaviour.BLOCK,
) )

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import collections.abc import collections.abc
import logging import logging
from typing import TYPE_CHECKING, Iterable, Optional, Set from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@ -29,7 +29,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import StateMap from synapse.types import JsonDict, StateMap
from synapse.util.caches import intern_string from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -241,7 +241,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# We delegate to the cached version # We delegate to the cached version
return await self.get_current_state_ids(room_id) return await self.get_current_state_ids(room_id)
def _get_filtered_current_state_ids_txn(txn): def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
) -> StateMap[str]:
results = {} results = {}
sql = """ sql = """
SELECT type, state_key, event_id FROM current_state_events SELECT type, state_key, event_id FROM current_state_events
@ -281,11 +283,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
event_id = state.get((EventTypes.CanonicalAlias, "")) event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id: if not event_id:
return return None
event = await self.get_event(event_id, allow_none=True) event = await self.get_event(event_id, allow_none=True)
if not event: if not event:
return return None
return event.content.get("canonical_alias") return event.content.get("canonical_alias")
@ -304,7 +306,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
list_name="event_ids", list_name="event_ids",
num_args=1, num_args=1,
) )
async def _get_state_group_for_events(self, event_ids): async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
"""Returns mapping event_id -> state_group""" """Returns mapping event_id -> state_group"""
rows = await self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups", table="event_to_state_groups",
@ -355,7 +357,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name: str = hs.hostname
self.db_pool.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME, self.CURRENT_STATE_INDEX_UPDATE_NAME,
@ -375,7 +377,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
self._background_remove_left_rooms, self._background_remove_left_rooms,
) )
async def _background_remove_left_rooms(self, progress, batch_size): async def _background_remove_left_rooms(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to delete rows from `current_state_events` and """Background update to delete rows from `current_state_events` and
`event_forward_extremities` tables of rooms that the server is no `event_forward_extremities` tables of rooms that the server is no
longer joined to. longer joined to.
@ -383,7 +387,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
last_room_id = progress.get("last_room_id", "") last_room_id = progress.get("last_room_id", "")
def _background_remove_left_rooms_txn(txn): def _background_remove_left_rooms_txn(
txn: LoggingTransaction,
) -> Tuple[bool, Set[str]]:
# get a batch of room ids to consider # get a batch of room ids to consider
sql = """ sql = """
SELECT DISTINCT room_id FROM current_state_events SELECT DISTINCT room_id FROM current_state_events

View file

@ -108,7 +108,7 @@ class StatsStore(StateDeltasStore):
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name: str = hs.hostname
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.stats_enabled = hs.config.stats.stats_enabled self.stats_enabled = hs.config.stats.stats_enabled

View file

@ -68,7 +68,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
) -> None: ) -> None:
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name: str = hs.hostname
self.db_pool.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
"populate_user_directory_createtables", "populate_user_directory_createtables",