mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 19:53:51 +01:00
Add StreamStore to mypy (#8232)
This commit is contained in:
parent
5a1dd297c3
commit
112266eafd
5 changed files with 66 additions and 20 deletions
1
changelog.d/8232.misc
Normal file
1
changelog.d/8232.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add type hints to `StreamStore`.
|
1
mypy.ini
1
mypy.ini
|
@ -43,6 +43,7 @@ files =
|
||||||
synapse/server_notices,
|
synapse/server_notices,
|
||||||
synapse/spam_checker_api,
|
synapse/spam_checker_api,
|
||||||
synapse/state,
|
synapse/state,
|
||||||
|
synapse/storage/databases/main/stream.py,
|
||||||
synapse/storage/databases/main/ui_auth.py,
|
synapse/storage/databases/main/ui_auth.py,
|
||||||
synapse/storage/database.py,
|
synapse/storage/database.py,
|
||||||
synapse/storage/engines,
|
synapse/storage/engines,
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
import abc
|
import abc
|
||||||
import os
|
import os
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
from typing import Dict, Optional, Type
|
from typing import Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
from unpaddedbase64 import encode_base64
|
from unpaddedbase64 import encode_base64
|
||||||
|
|
||||||
|
@ -120,7 +120,7 @@ class _EventInternalMetadata(object):
|
||||||
# be here
|
# be here
|
||||||
before = DictProperty("before") # type: str
|
before = DictProperty("before") # type: str
|
||||||
after = DictProperty("after") # type: str
|
after = DictProperty("after") # type: str
|
||||||
order = DictProperty("order") # type: int
|
order = DictProperty("order") # type: Tuple[int, int]
|
||||||
|
|
||||||
def get_dict(self) -> JsonDict:
|
def get_dict(self) -> JsonDict:
|
||||||
return dict(self._dict)
|
return dict(self._dict)
|
||||||
|
|
|
@ -604,6 +604,18 @@ class DatabasePool(object):
|
||||||
results = [dict(zip(col_headers, row)) for row in cursor]
|
results = [dict(zip(col_headers, row)) for row in cursor]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def execute(
|
||||||
|
self, desc: str, decoder: Literal[None], query: str, *args: Any
|
||||||
|
) -> List[Tuple[Any, ...]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def execute(
|
||||||
|
self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
|
||||||
|
) -> R:
|
||||||
|
...
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
desc: str,
|
desc: str,
|
||||||
|
@ -1088,6 +1100,28 @@ class DatabasePool(object):
|
||||||
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
|
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def simple_select_one_onecol(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
keyvalues: Dict[str, Any],
|
||||||
|
retcol: Iterable[str],
|
||||||
|
allow_none: Literal[False] = False,
|
||||||
|
desc: str = "simple_select_one_onecol",
|
||||||
|
) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def simple_select_one_onecol(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
keyvalues: Dict[str, Any],
|
||||||
|
retcol: Iterable[str],
|
||||||
|
allow_none: Literal[True] = True,
|
||||||
|
desc: str = "simple_select_one_onecol",
|
||||||
|
) -> Optional[Any]:
|
||||||
|
...
|
||||||
|
|
||||||
async def simple_select_one_onecol(
|
async def simple_select_one_onecol(
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
|
|
|
@ -39,7 +39,7 @@ what sort order was used:
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Dict, Iterable, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -54,9 +54,12 @@ from synapse.storage.database import (
|
||||||
)
|
)
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import Collection, RoomStreamToken
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -206,7 +209,7 @@ def _make_generic_sql_bound(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
|
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
|
||||||
# NB: This may create SQL clauses that don't optimise well (and we don't
|
# NB: This may create SQL clauses that don't optimise well (and we don't
|
||||||
# have indices on all possible clauses). E.g. it may create
|
# have indices on all possible clauses). E.g. it may create
|
||||||
# "room_id == X AND room_id != X", which postgres doesn't optimise.
|
# "room_id == X AND room_id != X", which postgres doesn't optimise.
|
||||||
|
@ -264,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
|
|
||||||
__metaclass__ = abc.ABCMeta
|
__metaclass__ = abc.ABCMeta
|
||||||
|
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
|
||||||
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
|
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self._instance_name = hs.get_instance_name()
|
self._instance_name = hs.get_instance_name()
|
||||||
|
@ -297,16 +300,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
self._stream_order_on_start = self.get_room_max_stream_ordering()
|
self._stream_order_on_start = self.get_room_max_stream_ordering()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_room_max_stream_ordering(self):
|
def get_room_max_stream_ordering(self) -> int:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_room_min_stream_ordering(self):
|
def get_room_min_stream_ordering(self) -> int:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def get_room_events_stream_for_rooms(
|
async def get_room_events_stream_for_rooms(
|
||||||
self,
|
self,
|
||||||
room_ids: Iterable[str],
|
room_ids: Collection[str],
|
||||||
from_key: str,
|
from_key: str,
|
||||||
to_key: str,
|
to_key: str,
|
||||||
limit: int = 0,
|
limit: int = 0,
|
||||||
|
@ -360,19 +363,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def get_rooms_that_changed(self, room_ids, from_key):
|
def get_rooms_that_changed(
|
||||||
|
self, room_ids: Collection[str], from_key: str
|
||||||
|
) -> Set[str]:
|
||||||
"""Given a list of rooms and a token, return rooms where there may have
|
"""Given a list of rooms and a token, return rooms where there may have
|
||||||
been changes.
|
been changes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_ids (list)
|
room_ids
|
||||||
from_key (str): The room_key portion of a StreamToken
|
from_key: The room_key portion of a StreamToken
|
||||||
"""
|
"""
|
||||||
from_key = RoomStreamToken.parse_stream_token(from_key).stream
|
from_id = RoomStreamToken.parse_stream_token(from_key).stream
|
||||||
return {
|
return {
|
||||||
room_id
|
room_id
|
||||||
for room_id in room_ids
|
for room_id in room_ids
|
||||||
if self._events_stream_cache.has_entity_changed(room_id, from_key)
|
if self._events_stream_cache.has_entity_changed(room_id, from_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_room_events_stream_for_room(
|
async def get_room_events_stream_for_room(
|
||||||
|
@ -444,7 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
|
|
||||||
return ret, key
|
return ret, key
|
||||||
|
|
||||||
async def get_membership_changes_for_user(self, user_id, from_key, to_key):
|
async def get_membership_changes_for_user(
|
||||||
|
self, user_id: str, from_key: str, to_key: str
|
||||||
|
) -> List[EventBase]:
|
||||||
from_id = RoomStreamToken.parse_stream_token(from_key).stream
|
from_id = RoomStreamToken.parse_stream_token(from_key).stream
|
||||||
to_id = RoomStreamToken.parse_stream_token(to_key).stream
|
to_id = RoomStreamToken.parse_stream_token(to_key).stream
|
||||||
|
|
||||||
|
@ -661,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
)
|
)
|
||||||
return row[0][0] if row else 0
|
return row[0][0] if row else 0
|
||||||
|
|
||||||
def _get_max_topological_txn(self, txn, room_id):
|
def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
|
"SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
|
||||||
(room_id,),
|
(room_id,),
|
||||||
|
@ -734,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
|
|
||||||
def _get_events_around_txn(
|
def _get_events_around_txn(
|
||||||
self,
|
self,
|
||||||
txn,
|
txn: LoggingTransaction,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_id: str,
|
event_id: str,
|
||||||
before_limit: int,
|
before_limit: int,
|
||||||
|
@ -762,6 +769,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
retcols=["stream_ordering", "topological_ordering"],
|
retcols=["stream_ordering", "topological_ordering"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# This cannot happen as `allow_none=False`.
|
||||||
|
assert results is not None
|
||||||
|
|
||||||
# Paginating backwards includes the event at the token, but paginating
|
# Paginating backwards includes the event at the token, but paginating
|
||||||
# forward doesn't.
|
# forward doesn't.
|
||||||
before_token = RoomStreamToken(
|
before_token = RoomStreamToken(
|
||||||
|
@ -871,7 +881,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
desc="update_federation_out_pos",
|
desc="update_federation_out_pos",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _reset_federation_positions_txn(self, txn) -> None:
|
def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
|
||||||
"""Fiddles with the `federation_stream_position` table to make it match
|
"""Fiddles with the `federation_stream_position` table to make it match
|
||||||
the configured federation sender instances during start up.
|
the configured federation sender instances during start up.
|
||||||
"""
|
"""
|
||||||
|
@ -910,7 +920,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
GROUP BY type
|
GROUP BY type
|
||||||
"""
|
"""
|
||||||
txn.execute(sql)
|
txn.execute(sql)
|
||||||
min_positions = dict(txn) # Map from type -> min position
|
min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position
|
||||||
|
|
||||||
# Ensure we do actually have some values here
|
# Ensure we do actually have some values here
|
||||||
assert set(min_positions) == {"federation", "events"}
|
assert set(min_positions) == {"federation", "events"}
|
||||||
|
@ -937,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||||
|
|
||||||
def _paginate_room_events_txn(
|
def _paginate_room_events_txn(
|
||||||
self,
|
self,
|
||||||
txn,
|
txn: LoggingTransaction,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
from_token: RoomStreamToken,
|
from_token: RoomStreamToken,
|
||||||
to_token: Optional[RoomStreamToken] = None,
|
to_token: Optional[RoomStreamToken] = None,
|
||||||
|
|
Loading…
Reference in a new issue