Annotate synapse.storage.util (#10892)

Also mark `synapse.streams` as having no untyped defs

Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com>
David Robertson, 2021-10-08 15:25:16 +01:00, committed by GitHub
parent 797ee7812d
commit 51a5da74cc
8 changed files with 124 additions and 65 deletions

changelog.d/10892.misc (new file):

```diff
@@ -0,0 +1 @@
+Add further type hints to `synapse.storage.util`.
```

mypy.ini:

```diff
@@ -105,6 +105,12 @@ disallow_untyped_defs = True
 [mypy-synapse.state.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.util.*]
+disallow_untyped_defs = True
+
+[mypy-synapse.streams.*]
+disallow_untyped_defs = True
+
 [mypy-synapse.util.batching_queue]
 disallow_untyped_defs = True
 
```
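For readers unfamiliar with the flag: `disallow_untyped_defs` makes mypy reject any function in the covered packages whose signature is not fully annotated. A minimal illustration (a hypothetical module, not part of this commit):

```python
# With disallow_untyped_defs = True applied to this module:

def bump(counter):  # error: Function is missing a type annotation
    return counter + 1


def bump_typed(counter: int) -> int:  # OK: fully annotated and type-checked
    return counter + 1
```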

synapse/replication/slave/storage/_slaved_id_tracker.py:

```diff
@@ -13,14 +13,14 @@
 # limitations under the License.
 from typing import List, Optional, Tuple
 
-from synapse.storage.types import Connection
+from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.util.id_generators import _load_current_id
 
 
 class SlavedIdTracker:
     def __init__(
         self,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         table: str,
         column: str,
         extra_tables: Optional[List[Tuple[str, str]]] = None,
```
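The `Connection` → `LoggingDatabaseConnection` swap recurs throughout this commit. The likely reason (an assumption drawn from the calls visible further down, not stated in the commit message) is that the id generators call `cursor(txn_name=...)`, which the bare DB-API `Connection` protocol does not accept:

```python
from typing import Any, Optional, Protocol


class Connection(Protocol):
    # Bare DB-API connection, as in synapse.storage.types; signature
    # paraphrased here, an assumption.
    def cursor(self) -> Any:
        ...


class LoggingDatabaseConnection:
    # Synapse's logging wrapper adds a txn_name keyword; signature
    # paraphrased here, an assumption.
    def cursor(self, txn_name: Optional[str] = None) -> Any:
        ...


# _load_current_id does db_conn.cursor(txn_name="_load_current_id"),
# which only type-checks when db_conn is a LoggingDatabaseConnection.
```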

synapse/replication/slave/storage/pushers.py:

```diff
@@ -15,9 +15,8 @@
 from typing import TYPE_CHECKING
 
 from synapse.replication.tcp.streams import PushersStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.pusher import PusherWorkerStore
-from synapse.storage.types import Connection
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
@@ -27,7 +26,12 @@ if TYPE_CHECKING:
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self._pushers_id_gen = SlavedIdTracker(  # type: ignore
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
```

synapse/storage/databases/main/pusher.py:

```diff
@@ -18,8 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional,
 from synapse.push import PusherConfig, ThrottleParams
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
@@ -32,7 +31,12 @@ logger = logging.getLogger(__name__)
 class PusherWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self._pushers_id_gen = StreamIdGenerator(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
```

synapse/storage/databases/main/registration.py:

```diff
@@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
-from synapse.storage.types import Connection, Cursor
+from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import IdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import UserID, UserInfo
@@ -1775,7 +1775,12 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._ignore_unknown_session_error = (
```

synapse/storage/util/id_generators.py:

```diff
@@ -16,42 +16,62 @@ import logging
 import threading
 from collections import OrderedDict
 from contextlib import contextmanager
-from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+from types import TracebackType
+from typing import (
+    AsyncContextManager,
+    ContextManager,
+    Dict,
+    Generator,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import attr
 from sortedcontainers import SortedSet
 
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import PostgresSequenceGenerator
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T")
+
 
 class IdGenerator:
-    def __init__(self, db_conn, table, column):
+    def __init__(
+        self,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        column: str,
+    ):
         self._lock = threading.Lock()
         self._next_id = _load_current_id(db_conn, table, column)
 
-    def get_next(self):
+    def get_next(self) -> int:
         with self._lock:
             self._next_id += 1
             return self._next_id
 
 
-def _load_current_id(db_conn, table, column, step=1):
-    """
-    Args:
-        db_conn (object):
-        table (str):
-        column (str):
-        step (int):
-    Returns:
-        int
-    """
+def _load_current_id(
+    db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1
+) -> int:
     # debug logging for https://github.com/matrix-org/synapse/issues/7968
     logger.info("initialising stream generator for %s(%s)", table, column)
     cur = db_conn.cursor(txn_name="_load_current_id")
@@ -59,7 +79,9 @@ def _load_current_id(db_conn, table, column, step=1)
         cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
     else:
         cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
-    (val,) = cur.fetchone()
+    result = cur.fetchone()
+    assert result is not None
+    (val,) = result
     cur.close()
     current_id = int(val) if val else step
     return (max if step > 0 else min)(current_id, step)
```
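The `fetchone()` rewrite here (and again in `MultiWriterIdGenerator._load_current_ids` below) is the standard pattern for typed DB code: in Synapse's `Cursor` protocol, as in typed DB-API stubs generally, `fetchone()` returns an optional row, so unpacking it directly is a type error. A self-contained sketch using sqlite3 (the explicit annotation is added for the sketch):

```python
import sqlite3
from typing import Optional, Tuple

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("SELECT MAX(id) FROM (SELECT 1 AS id)")

# fetchone() may return None (no rows), so its type is Optional. The
# assert narrows Optional[Tuple] to Tuple before unpacking, mirroring
# the change above.
result: Optional[Tuple[int]] = cur.fetchone()
assert result is not None
(val,) = result
print(val)  # 1
```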
```diff
@@ -93,16 +115,16 @@ class StreamIdGenerator:
     def __init__(
         self,
-        db_conn,
-        table,
-        column,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        column: str,
         extra_tables: Iterable[Tuple[str, str]] = (),
-        step=1,
-    ):
+        step: int = 1,
+    ) -> None:
         assert step != 0
         self._lock = threading.Lock()
-        self._step = step
-        self._current = _load_current_id(db_conn, table, column, step)
+        self._step: int = step
+        self._current: int = _load_current_id(db_conn, table, column, step)
         for table, column in extra_tables:
             self._current = (max if step > 0 else min)(
                 self._current, _load_current_id(db_conn, table, column, step)
@@ -115,7 +137,7 @@ class StreamIdGenerator:
         # The key and values are the same, but we never look at the values.
         self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
 
-    def get_next(self):
+    def get_next(self) -> AsyncContextManager[int]:
         """
         Usage:
             async with stream_id_gen.get_next() as stream_id:
@@ -128,7 +150,7 @@
             self._unfinished_ids[next_id] = next_id
 
         @contextmanager
-        def manager():
+        def manager() -> Generator[int, None, None]:
             try:
                 yield next_id
             finally:
@@ -137,7 +159,7 @@
         return _AsyncCtxManagerWrapper(manager())
 
-    def get_next_mult(self, n):
+    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
         """
         Usage:
             async with stream_id_gen.get_next(n) as stream_ids:
@@ -155,7 +177,7 @@
             self._unfinished_ids[next_id] = next_id
 
         @contextmanager
-        def manager():
+        def manager() -> Generator[Sequence[int], None, None]:
             try:
                 yield next_ids
             finally:
```
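The new `AsyncContextManager[int]` return type matches the documented usage, which this sketch spells out (the caller is hypothetical; `StreamIdGenerator` is the class above):

```python
async def persist_one(stream_id_gen: "StreamIdGenerator") -> None:
    # get_next() returns an async context manager rather than an int:
    # the ID counts as in-flight until the block exits, at which point
    # it is marked finished. Entering the manager yields the actual int.
    async with stream_id_gen.get_next() as stream_id:
        print("persisting at stream position", stream_id)
```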
```diff
@@ -215,7 +237,7 @@ class MultiWriterIdGenerator:
     def __init__(
         self,
-        db_conn,
+        db_conn: LoggingDatabaseConnection,
         db: DatabasePool,
         stream_name: str,
         instance_name: str,
@@ -223,7 +245,7 @@
         sequence_name: str,
         writers: List[str],
         positive: bool = True,
-    ):
+    ) -> None:
         self._db = db
         self._stream_name = stream_name
         self._instance_name = instance_name
@@ -285,9 +307,9 @@
     def _load_current_ids(
         self,
-        db_conn,
+        db_conn: LoggingDatabaseConnection,
         tables: List[Tuple[str, str, str]],
-    ):
+    ) -> None:
         cur = db_conn.cursor(txn_name="_load_current_ids")
 
         # Load the current positions of all writers for the stream.
@@ -335,7 +357,9 @@
                 "agg": "MAX" if self._positive else "-MIN",
             }
             cur.execute(sql)
-            (stream_id,) = cur.fetchone()
+            result = cur.fetchone()
+            assert result is not None
+            (stream_id,) = result
 
             max_stream_id = max(max_stream_id, stream_id)
@@ -354,7 +378,7 @@
             self._persisted_upto_position = min_stream_id
 
-            rows = []
+            rows: List[Tuple[str, int]] = []
             for table, instance_column, id_column in tables:
                 sql = """
                     SELECT %(instance)s, %(id)s FROM %(table)s
@@ -367,7 +391,8 @@
                 }
                 cur.execute(sql, (min_stream_id * self._return_factor,))
 
-                rows.extend(cur)
+                # Cast safety: this corresponds to the types returned by the query above.
+                rows.extend(cast(Iterable[Tuple[str, int]], cur))
 
             # Sort so that we handle rows in order for each instance.
             rows.sort()
```
```diff
@@ -385,13 +410,13 @@ class MultiWriterIdGenerator:
         cur.close()
 
-    def _load_next_id_txn(self, txn) -> int:
+    def _load_next_id_txn(self, txn: Cursor) -> int:
         return self._sequence_gen.get_next_id_txn(txn)
 
-    def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
+    def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
         return self._sequence_gen.get_next_mult_txn(txn, n)
 
-    def get_next(self):
+    def get_next(self) -> AsyncContextManager[int]:
         """
         Usage:
             async with stream_id_gen.get_next() as stream_id:
@@ -403,9 +428,12 @@
         if self._writers and self._instance_name not in self._writers:
             raise Exception("Tried to allocate stream ID on non-writer")
 
-        return _MultiWriterCtxManager(self)
+        # Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
+        # controls the return type. If `None` or omitted, the context manager yields
+        # a single integer stream_id; otherwise it yields a list of stream_ids.
+        return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
 
-    def get_next_mult(self, n: int):
+    def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
         """
         Usage:
             async with stream_id_gen.get_next_mult(5) as stream_ids:
```
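`cast` deserves a word, since this commit leans on it twice: it has no runtime effect and simply overrides the inferred type, which is why each use above is paired with a "Cast safety" comment justifying the claim. A standalone illustration (not code from this commit):

```python
from typing import List, Union, cast


def only_ints(raw: List[Union[int, str]]) -> List[int]:
    # Suppose we know, by construction elsewhere, that raw holds only
    # ints here. cast() records that claim for mypy but compiles to a
    # plain return; if the claim is wrong, nothing fails at this line.
    return cast(List[int], raw)
```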
```diff
@@ -457,7 +486,7 @@
         return self._return_factor * next_id
 
-    def _mark_id_as_finished(self, next_id: int):
+    def _mark_id_as_finished(self, next_id: int) -> None:
         """The ID has finished being processed so we should advance the
         current position if possible.
         """
@@ -534,7 +563,7 @@
             for name, i in self._current_positions.items()
         }
 
-    def advance(self, instance_name: str, new_id: int):
+    def advance(self, instance_name: str, new_id: int) -> None:
         """Advance the position of the named writer to the given ID, if greater
         than existing entry.
         """
@@ -560,7 +589,7 @@
         with self._lock:
             return self._return_factor * self._persisted_upto_position
 
-    def _add_persisted_position(self, new_id: int):
+    def _add_persisted_position(self, new_id: int) -> None:
         """Record that we have persisted a position.
 
         This is used to keep the `_current_positions` up to date.
@@ -606,7 +635,7 @@
                 # do.
                 break
 
-    def _update_stream_positions_table_txn(self, txn: Cursor):
+    def _update_stream_positions_table_txn(self, txn: Cursor) -> None:
         """Update the `stream_positions` table with newly persisted position."""
 
         if not self._writers:
```
```diff
@@ -628,20 +657,25 @@
         txn.execute(sql, (self._stream_name, self._instance_name, pos))
 
-@attr.s(slots=True)
-class _AsyncCtxManagerWrapper:
+@attr.s(frozen=True, auto_attribs=True)
+class _AsyncCtxManagerWrapper(Generic[T]):
     """Helper class to convert a plain context manager to an async one.
 
     This is mainly useful if you have a plain context manager but the interface
     requires an async one.
     """
 
-    inner = attr.ib()
+    inner: ContextManager[T]
 
-    async def __aenter__(self):
+    async def __aenter__(self) -> T:
         return self.inner.__enter__()
 
-    async def __aexit__(self, exc_type, exc, tb):
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> Optional[bool]:
         return self.inner.__exit__(exc_type, exc, tb)
```
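With `Generic[T]`, mypy can propagate the wrapped value's type: wrapping a `ContextManager[int]` produces an async context manager that yields an `int`. A self-contained sketch of the same pattern (hypothetical names, same attrs idiom):

```python
import asyncio
from contextlib import contextmanager
from types import TracebackType
from typing import ContextManager, Generator, Generic, Optional, Type, TypeVar

import attr

T = TypeVar("T")


@attr.s(frozen=True, auto_attribs=True)
class AsyncWrapper(Generic[T]):
    """Adapts a plain context manager to the async protocol."""

    inner: ContextManager[T]

    async def __aenter__(self) -> T:
        return self.inner.__enter__()

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> Optional[bool]:
        return self.inner.__exit__(exc_type, exc, tb)


@contextmanager
def give_five() -> Generator[int, None, None]:
    yield 5


async def main() -> None:
    async with AsyncWrapper(give_five()) as n:  # mypy infers n: int
        print(n)


asyncio.run(main())
```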
```diff
@@ -671,7 +705,12 @@ class _MultiWriterCtxManager:
         else:
             return [i * self.id_gen._return_factor for i in self.stream_ids]
 
-    async def __aexit__(self, exc_type, exc, tb):
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> bool:
         for i in self.stream_ids:
             self.id_gen._mark_id_as_finished(i)
```

synapse/storage/util/sequence.py:

```diff
@@ -81,7 +81,7 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         id_column: str,
         stream_name: Optional[str] = None,
         positive: bool = True,
-    ):
+    ) -> None:
         """Should be called during start up to test that the current value of
         the sequence is greater than or equal to the maximum ID in the table.
@@ -122,7 +122,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
         id_column: str,
         stream_name: Optional[str] = None,
         positive: bool = True,
-    ):
+    ) -> None:
         """See SequenceGenerator.check_consistency for docstring."""
 
         txn = db_conn.cursor(txn_name="sequence.check_consistency")
@@ -244,7 +244,7 @@ class LocalSequenceGenerator(SequenceGenerator):
         id_column: str,
         stream_name: Optional[str] = None,
         positive: bool = True,
-    ):
+    ) -> None:
         # There is nothing to do for in memory sequences
         pass
```
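The repeated `) -> None:` additions are not cosmetic: under `disallow_untyped_defs`, a function with annotated parameters but no return annotation still counts as incompletely typed, and its return type is implicitly `Any`. A small illustration (hypothetical functions):

```python
def check(stream_name: str, positive: bool = True):
    # error under disallow_untyped_defs: the signature is incomplete
    # (missing return annotation), so the return type defaults to Any.
    ...


def check_typed(stream_name: str, positive: bool = True) -> None:
    ...  # OK: the signature is now complete
```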