mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 03:33:48 +01:00
Add StreamStore to mypy (#8232)
This commit is contained in:
parent
5a1dd297c3
commit
112266eafd
5 changed files with 66 additions and 20 deletions
1
changelog.d/8232.misc
Normal file
1
changelog.d/8232.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Add type hints to `StreamStore`.
|
1
mypy.ini
1
mypy.ini
|
@ -43,6 +43,7 @@ files =
|
|||
synapse/server_notices,
|
||||
synapse/spam_checker_api,
|
||||
synapse/state,
|
||||
synapse/storage/databases/main/stream.py,
|
||||
synapse/storage/databases/main/ui_auth.py,
|
||||
synapse/storage/database.py,
|
||||
synapse/storage/engines,
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
import abc
|
||||
import os
|
||||
from distutils.util import strtobool
|
||||
from typing import Dict, Optional, Type
|
||||
from typing import Dict, Optional, Tuple, Type
|
||||
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
|
@ -120,7 +120,7 @@ class _EventInternalMetadata(object):
|
|||
# be here
|
||||
before = DictProperty("before") # type: str
|
||||
after = DictProperty("after") # type: str
|
||||
order = DictProperty("order") # type: int
|
||||
order = DictProperty("order") # type: Tuple[int, int]
|
||||
|
||||
def get_dict(self) -> JsonDict:
|
||||
return dict(self._dict)
|
||||
|
|
|
@ -604,6 +604,18 @@ class DatabasePool(object):
|
|||
results = [dict(zip(col_headers, row)) for row in cursor]
|
||||
return results
|
||||
|
||||
@overload
|
||||
async def execute(
|
||||
self, desc: str, decoder: Literal[None], query: str, *args: Any
|
||||
) -> List[Tuple[Any, ...]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def execute(
|
||||
self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
|
||||
) -> R:
|
||||
...
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
desc: str,
|
||||
|
@ -1088,6 +1100,28 @@ class DatabasePool(object):
|
|||
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
|
||||
)
|
||||
|
||||
@overload
|
||||
async def simple_select_one_onecol(
|
||||
self,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcol: Iterable[str],
|
||||
allow_none: Literal[False] = False,
|
||||
desc: str = "simple_select_one_onecol",
|
||||
) -> Any:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def simple_select_one_onecol(
|
||||
self,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcol: Iterable[str],
|
||||
allow_none: Literal[True] = True,
|
||||
desc: str = "simple_select_one_onecol",
|
||||
) -> Optional[Any]:
|
||||
...
|
||||
|
||||
async def simple_select_one_onecol(
|
||||
self,
|
||||
table: str,
|
||||
|
|
|
@ -39,7 +39,7 @@ what sort order was used:
|
|||
import abc
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -54,9 +54,12 @@ from synapse.storage.database import (
|
|||
)
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||
from synapse.types import RoomStreamToken
|
||||
from synapse.types import Collection, RoomStreamToken
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -206,7 +209,7 @@ def _make_generic_sql_bound(
|
|||
)
|
||||
|
||||
|
||||
def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
|
||||
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
|
||||
# NB: This may create SQL clauses that don't optimise well (and we don't
|
||||
# have indices on all possible clauses). E.g. it may create
|
||||
# "room_id == X AND room_id != X", which postgres doesn't optimise.
|
||||
|
@ -264,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
|
||||
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
@ -297,16 +300,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
self._stream_order_on_start = self.get_room_max_stream_ordering()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_room_max_stream_ordering(self):
|
||||
def get_room_max_stream_ordering(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_room_min_stream_ordering(self):
|
||||
def get_room_min_stream_ordering(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get_room_events_stream_for_rooms(
|
||||
self,
|
||||
room_ids: Iterable[str],
|
||||
room_ids: Collection[str],
|
||||
from_key: str,
|
||||
to_key: str,
|
||||
limit: int = 0,
|
||||
|
@ -360,19 +363,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return results
|
||||
|
||||
def get_rooms_that_changed(self, room_ids, from_key):
|
||||
def get_rooms_that_changed(
|
||||
self, room_ids: Collection[str], from_key: str
|
||||
) -> Set[str]:
|
||||
"""Given a list of rooms and a token, return rooms where there may have
|
||||
been changes.
|
||||
|
||||
Args:
|
||||
room_ids (list)
|
||||
from_key (str): The room_key portion of a StreamToken
|
||||
room_ids
|
||||
from_key: The room_key portion of a StreamToken
|
||||
"""
|
||||
from_key = RoomStreamToken.parse_stream_token(from_key).stream
|
||||
from_id = RoomStreamToken.parse_stream_token(from_key).stream
|
||||
return {
|
||||
room_id
|
||||
for room_id in room_ids
|
||||
if self._events_stream_cache.has_entity_changed(room_id, from_key)
|
||||
if self._events_stream_cache.has_entity_changed(room_id, from_id)
|
||||
}
|
||||
|
||||
async def get_room_events_stream_for_room(
|
||||
|
@ -444,7 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return ret, key
|
||||
|
||||
async def get_membership_changes_for_user(self, user_id, from_key, to_key):
|
||||
async def get_membership_changes_for_user(
|
||||
self, user_id: str, from_key: str, to_key: str
|
||||
) -> List[EventBase]:
|
||||
from_id = RoomStreamToken.parse_stream_token(from_key).stream
|
||||
to_id = RoomStreamToken.parse_stream_token(to_key).stream
|
||||
|
||||
|
@ -661,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
)
|
||||
return row[0][0] if row else 0
|
||||
|
||||
def _get_max_topological_txn(self, txn, room_id):
|
||||
def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
|
||||
txn.execute(
|
||||
"SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
|
||||
(room_id,),
|
||||
|
@ -734,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
def _get_events_around_txn(
|
||||
self,
|
||||
txn,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
before_limit: int,
|
||||
|
@ -762,6 +769,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
retcols=["stream_ordering", "topological_ordering"],
|
||||
)
|
||||
|
||||
# This cannot happen as `allow_none=False`.
|
||||
assert results is not None
|
||||
|
||||
# Paginating backwards includes the event at the token, but paginating
|
||||
# forward doesn't.
|
||||
before_token = RoomStreamToken(
|
||||
|
@ -871,7 +881,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
desc="update_federation_out_pos",
|
||||
)
|
||||
|
||||
def _reset_federation_positions_txn(self, txn) -> None:
|
||||
def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
|
||||
"""Fiddles with the `federation_stream_position` table to make it match
|
||||
the configured federation sender instances during start up.
|
||||
"""
|
||||
|
@ -910,7 +920,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
GROUP BY type
|
||||
"""
|
||||
txn.execute(sql)
|
||||
min_positions = dict(txn) # Map from type -> min position
|
||||
min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position
|
||||
|
||||
# Ensure we do actually have some values here
|
||||
assert set(min_positions) == {"federation", "events"}
|
||||
|
@ -937,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
def _paginate_room_events_txn(
|
||||
self,
|
||||
txn,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
from_token: RoomStreamToken,
|
||||
to_token: Optional[RoomStreamToken] = None,
|
||||
|
|
Loading…
Reference in a new issue