Add StreamStore to mypy (#8232)

This commit is contained in:
Erik Johnston 2020-09-02 17:52:38 +01:00 committed by GitHub
parent 5a1dd297c3
commit 112266eafd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 66 additions and 20 deletions

1
changelog.d/8232.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `StreamStore`.

View file

@ -43,6 +43,7 @@ files =
synapse/server_notices, synapse/server_notices,
synapse/spam_checker_api, synapse/spam_checker_api,
synapse/state, synapse/state,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py, synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py, synapse/storage/database.py,
synapse/storage/engines, synapse/storage/engines,

View file

@ -18,7 +18,7 @@
import abc import abc
import os import os
from distutils.util import strtobool from distutils.util import strtobool
from typing import Dict, Optional, Type from typing import Dict, Optional, Tuple, Type
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
@ -120,7 +120,7 @@ class _EventInternalMetadata(object):
# be here # be here
before = DictProperty("before") # type: str before = DictProperty("before") # type: str
after = DictProperty("after") # type: str after = DictProperty("after") # type: str
order = DictProperty("order") # type: int order = DictProperty("order") # type: Tuple[int, int]
def get_dict(self) -> JsonDict: def get_dict(self) -> JsonDict:
return dict(self._dict) return dict(self._dict)

View file

@ -604,6 +604,18 @@ class DatabasePool(object):
results = [dict(zip(col_headers, row)) for row in cursor] results = [dict(zip(col_headers, row)) for row in cursor]
return results return results
@overload
async def execute(
self, desc: str, decoder: Literal[None], query: str, *args: Any
) -> List[Tuple[Any, ...]]:
...
@overload
async def execute(
self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
) -> R:
...
async def execute( async def execute(
self, self,
desc: str, desc: str,
@ -1088,6 +1100,28 @@ class DatabasePool(object):
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
) )
@overload
async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one_onecol",
) -> Any:
...
@overload
async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one_onecol",
) -> Optional[Any]:
...
async def simple_select_one_onecol( async def simple_select_one_onecol(
self, self,
table: str, table: str,

View file

@ -39,7 +39,7 @@ what sort order was used:
import abc import abc
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from twisted.internet import defer from twisted.internet import defer
@ -54,9 +54,12 @@ from synapse.storage.database import (
) )
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.types import RoomStreamToken from synapse.types import Collection, RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -206,7 +209,7 @@ def _make_generic_sql_bound(
) )
def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]: def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't # NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create # have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise. # "room_id == X AND room_id != X", which postgres doesn't optimise.
@ -264,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super(StreamWorkerStore, self).__init__(database, db_conn, hs) super(StreamWorkerStore, self).__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
@ -297,16 +300,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._stream_order_on_start = self.get_room_max_stream_ordering() self._stream_order_on_start = self.get_room_max_stream_ordering()
@abc.abstractmethod @abc.abstractmethod
def get_room_max_stream_ordering(self): def get_room_max_stream_ordering(self) -> int:
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
def get_room_min_stream_ordering(self): def get_room_min_stream_ordering(self) -> int:
raise NotImplementedError() raise NotImplementedError()
async def get_room_events_stream_for_rooms( async def get_room_events_stream_for_rooms(
self, self,
room_ids: Iterable[str], room_ids: Collection[str],
from_key: str, from_key: str,
to_key: str, to_key: str,
limit: int = 0, limit: int = 0,
@ -360,19 +363,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results return results
def get_rooms_that_changed(self, room_ids, from_key): def get_rooms_that_changed(
self, room_ids: Collection[str], from_key: str
) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have """Given a list of rooms and a token, return rooms where there may have
been changes. been changes.
Args: Args:
room_ids (list) room_ids
from_key (str): The room_key portion of a StreamToken from_key: The room_key portion of a StreamToken
""" """
from_key = RoomStreamToken.parse_stream_token(from_key).stream from_id = RoomStreamToken.parse_stream_token(from_key).stream
return { return {
room_id room_id
for room_id in room_ids for room_id in room_ids
if self._events_stream_cache.has_entity_changed(room_id, from_key) if self._events_stream_cache.has_entity_changed(room_id, from_id)
} }
async def get_room_events_stream_for_room( async def get_room_events_stream_for_room(
@ -444,7 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key return ret, key
async def get_membership_changes_for_user(self, user_id, from_key, to_key): async def get_membership_changes_for_user(
self, user_id: str, from_key: str, to_key: str
) -> List[EventBase]:
from_id = RoomStreamToken.parse_stream_token(from_key).stream from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream
@ -661,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
) )
return row[0][0] if row else 0 return row[0][0] if row else 0
def _get_max_topological_txn(self, txn, room_id): def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
txn.execute( txn.execute(
"SELECT MAX(topological_ordering) FROM events WHERE room_id = ?", "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
(room_id,), (room_id,),
@ -734,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_events_around_txn( def _get_events_around_txn(
self, self,
txn, txn: LoggingTransaction,
room_id: str, room_id: str,
event_id: str, event_id: str,
before_limit: int, before_limit: int,
@ -762,6 +769,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=["stream_ordering", "topological_ordering"], retcols=["stream_ordering", "topological_ordering"],
) )
# This cannot happen as `allow_none=False`.
assert results is not None
# Paginating backwards includes the event at the token, but paginating # Paginating backwards includes the event at the token, but paginating
# forward doesn't. # forward doesn't.
before_token = RoomStreamToken( before_token = RoomStreamToken(
@ -871,7 +881,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="update_federation_out_pos", desc="update_federation_out_pos",
) )
def _reset_federation_positions_txn(self, txn) -> None: def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
"""Fiddles with the `federation_stream_position` table to make it match """Fiddles with the `federation_stream_position` table to make it match
the configured federation sender instances during start up. the configured federation sender instances during start up.
""" """
@ -910,7 +920,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
GROUP BY type GROUP BY type
""" """
txn.execute(sql) txn.execute(sql)
min_positions = dict(txn) # Map from type -> min position min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position
# Ensure we do actually have some values here # Ensure we do actually have some values here
assert set(min_positions) == {"federation", "events"} assert set(min_positions) == {"federation", "events"}
@ -937,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def _paginate_room_events_txn( def _paginate_room_events_txn(
self, self,
txn, txn: LoggingTransaction,
room_id: str, room_id: str,
from_token: RoomStreamToken, from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None, to_token: Optional[RoomStreamToken] = None,