0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 16:53:53 +01:00

Add StreamStore to mypy (#8232)

This commit is contained in:
Erik Johnston 2020-09-02 17:52:38 +01:00 committed by GitHub
parent 5a1dd297c3
commit 112266eafd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 66 additions and 20 deletions

1
changelog.d/8232.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `StreamStore`.

View file

@ -43,6 +43,7 @@ files =
synapse/server_notices,
synapse/spam_checker_api,
synapse/state,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py,
synapse/storage/engines,

View file

@ -18,7 +18,7 @@
import abc
import os
from distutils.util import strtobool
from typing import Dict, Optional, Type
from typing import Dict, Optional, Tuple, Type
from unpaddedbase64 import encode_base64
@ -120,7 +120,7 @@ class _EventInternalMetadata(object):
# be here
before = DictProperty("before") # type: str
after = DictProperty("after") # type: str
order = DictProperty("order") # type: int
order = DictProperty("order") # type: Tuple[int, int]
def get_dict(self) -> JsonDict:
return dict(self._dict)

View file

@ -604,6 +604,18 @@ class DatabasePool(object):
results = [dict(zip(col_headers, row)) for row in cursor]
return results
@overload
async def execute(
self, desc: str, decoder: Literal[None], query: str, *args: Any
) -> List[Tuple[Any, ...]]:
...
@overload
async def execute(
self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
) -> R:
...
async def execute(
self,
desc: str,
@ -1088,6 +1100,28 @@ class DatabasePool(object):
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)
@overload
async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one_onecol",
) -> Any:
...
@overload
async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one_onecol",
) -> Optional[Any]:
...
async def simple_select_one_onecol(
self,
table: str,

View file

@ -39,7 +39,7 @@ what sort order was used:
import abc
import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from twisted.internet import defer
@ -54,9 +54,12 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.types import RoomStreamToken
from synapse.types import Collection, RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -206,7 +209,7 @@ def _make_generic_sql_bound(
)
def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise.
@ -264,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
__metaclass__ = abc.ABCMeta
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
@ -297,16 +300,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._stream_order_on_start = self.get_room_max_stream_ordering()
@abc.abstractmethod
def get_room_max_stream_ordering(self):
def get_room_max_stream_ordering(self) -> int:
raise NotImplementedError()
@abc.abstractmethod
def get_room_min_stream_ordering(self):
def get_room_min_stream_ordering(self) -> int:
raise NotImplementedError()
async def get_room_events_stream_for_rooms(
self,
room_ids: Iterable[str],
room_ids: Collection[str],
from_key: str,
to_key: str,
limit: int = 0,
@ -360,19 +363,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
def get_rooms_that_changed(self, room_ids, from_key):
def get_rooms_that_changed(
self, room_ids: Collection[str], from_key: str
) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.
Args:
room_ids (list)
from_key (str): The room_key portion of a StreamToken
room_ids
from_key: The room_key portion of a StreamToken
"""
from_key = RoomStreamToken.parse_stream_token(from_key).stream
from_id = RoomStreamToken.parse_stream_token(from_key).stream
return {
room_id
for room_id in room_ids
if self._events_stream_cache.has_entity_changed(room_id, from_key)
if self._events_stream_cache.has_entity_changed(room_id, from_id)
}
async def get_room_events_stream_for_room(
@ -444,7 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(self, user_id, from_key, to_key):
async def get_membership_changes_for_user(
self, user_id: str, from_key: str, to_key: str
) -> List[EventBase]:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@ -661,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
return row[0][0] if row else 0
def _get_max_topological_txn(self, txn, room_id):
def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
txn.execute(
"SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
(room_id,),
@ -734,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_events_around_txn(
self,
txn,
txn: LoggingTransaction,
room_id: str,
event_id: str,
before_limit: int,
@ -762,6 +769,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=["stream_ordering", "topological_ordering"],
)
# This cannot happen as `allow_none=False`.
assert results is not None
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
@ -871,7 +881,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="update_federation_out_pos",
)
def _reset_federation_positions_txn(self, txn) -> None:
def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
"""Fiddles with the `federation_stream_position` table to make it match
the configured federation sender instances during start up.
"""
@ -910,7 +920,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
GROUP BY type
"""
txn.execute(sql)
min_positions = dict(txn) # Map from type -> min position
min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position
# Ensure we do actually have some values here
assert set(min_positions) == {"federation", "events"}
@ -937,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def _paginate_room_events_txn(
self,
txn,
txn: LoggingTransaction,
room_id: str,
from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None,