Add additional type hints to the storage module. (#8980)

This commit is contained in:
Patrick Cloke 2020-12-30 08:09:53 -05:00 committed by GitHub
parent b8591899ab
commit 637282bb50
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 224 additions and 148 deletions

1
changelog.d/8980.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to the base storage code.

View file

@ -70,6 +70,9 @@ files =
synapse/server_notices, synapse/server_notices,
synapse/spam_checker_api, synapse/spam_checker_api,
synapse/state, synapse/state,
synapse/storage/__init__.py,
synapse/storage/_base.py,
synapse/storage/background_updates.py,
synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/appservice.py,
synapse/storage/databases/main/events.py, synapse/storage/databases/main/events.py,
synapse/storage/databases/main/pusher.py, synapse/storage/databases/main/pusher.py,
@ -78,8 +81,15 @@ files =
synapse/storage/databases/main/ui_auth.py, synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py, synapse/storage/database.py,
synapse/storage/engines, synapse/storage/engines,
synapse/storage/keys.py,
synapse/storage/persist_events.py, synapse/storage/persist_events.py,
synapse/storage/prepare_database.py,
synapse/storage/purge_events.py,
synapse/storage/push_rule.py,
synapse/storage/relations.py,
synapse/storage/roommember.py,
synapse/storage/state.py, synapse/storage/state.py,
synapse/storage/types.py,
synapse/storage/util, synapse/storage/util,
synapse/streams, synapse/streams,
synapse/types.py, synapse/types.py,

View file

@ -323,9 +323,7 @@ class InitialSyncHandler(BaseHandler):
member_event_id: str, member_event_id: str,
is_peeking: bool, is_peeking: bool,
) -> JsonDict: ) -> JsonDict:
room_state = await self.state_store.get_state_for_events([member_event_id]) room_state = await self.state_store.get_state_for_event(member_event_id)
room_state = room_state[member_event_id]
limit = pagin_config.limit if pagin_config else None limit = pagin_config.limit if pagin_config else None
if limit is None: if limit is None:

View file

@ -554,7 +554,7 @@ class SyncHandler:
event.event_id, state_filter=state_filter event.event_id, state_filter=state_filter
) )
if event.is_state(): if event.is_state():
state_ids = state_ids.copy() state_ids = dict(state_ids)
state_ids[(event.type, event.state_key)] = event.event_id state_ids[(event.type, event.state_key)] = event.event_id
return state_ids return state_ids

View file

@ -27,6 +27,7 @@ There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`. stored in `synapse.storage.schema`.
""" """
from typing import TYPE_CHECKING
from synapse.storage.databases import Databases from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
@ -34,14 +35,18 @@ from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.purge_events import PurgeEventsStorage from synapse.storage.purge_events import PurgeEventsStorage
from synapse.storage.state import StateGroupStorage from synapse.storage.state import StateGroupStorage
__all__ = ["DataStores", "DataStore"] if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
__all__ = ["Databases", "DataStore"]
class Storage: class Storage:
"""The high level interfaces for talking to various storage layers. """The high level interfaces for talking to various storage layers.
""" """
def __init__(self, hs, stores: Databases): def __init__(self, hs: "HomeServer", stores: Databases):
# We include the main data store here mainly so that we don't have to # We include the main data store here mainly so that we don't have to
# rewrite all the existing code to split it into high vs low level # rewrite all the existing code to split it into high vs low level
# interfaces. # interfaces.

View file

@ -17,14 +17,18 @@
import logging import logging
import random import random
from abc import ABCMeta from abc import ABCMeta
from typing import Any, Optional from typing import TYPE_CHECKING, Any, Iterable, Optional, Union
from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.types import Collection, get_domain_from_id from synapse.storage.types import Connection
from synapse.types import Collection, StreamToken, get_domain_from_id
from synapse.util import json_decoder from synapse.util import json_decoder
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,24 +40,31 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database). per data store (and not one per physical database).
""" """
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.hs = hs self.hs = hs
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.database_engine = database.engine self.database_engine = database.engine
self.db_pool = database self.db_pool = database
self.rand = random.SystemRandom() self.rand = random.SystemRandom()
def process_replication_rows(self, stream_name, instance_name, token, rows): def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: StreamToken,
rows: Iterable[Any],
) -> None:
pass pass
def _invalidate_state_caches(self, room_id, members_changed): def _invalidate_state_caches(
self, room_id: str, members_changed: Iterable[str]
) -> None:
"""Invalidates caches that are based on the current state, but does """Invalidates caches that are based on the current state, but does
not stream invalidations down replication. not stream invalidations down replication.
Args: Args:
room_id (str): Room where state changed room_id: Room where state changed
members_changed (iterable[str]): The user_ids of members that have members_changed: The user_ids of members that have changed
changed
""" """
for host in {get_domain_from_id(u) for u in members_changed}: for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
@ -64,7 +75,7 @@ class SQLBaseStore(metaclass=ABCMeta):
def _attempt_to_invalidate_cache( def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]] self, cache_name: str, key: Optional[Collection[Any]]
): ) -> None:
"""Attempts to invalidate the cache of the given name, ignoring if the """Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers, cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache. where they may not have the cache.
@ -88,12 +99,15 @@ class SQLBaseStore(metaclass=ABCMeta):
cache.invalidate(tuple(key)) cache.invalidate(tuple(key))
def db_to_json(db_content): def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
""" """
Take some data from a database row and return a JSON-decoded object. Take some data from a database row and return a JSON-decoded object.
Args: Args:
db_content (memoryview|buffer|bytes|bytearray|unicode) db_content: The JSON-encoded contents from the database.
Returns:
The object decoded from JSON.
""" """
# psycopg2 on Python 3 returns memoryview objects, which we need to # psycopg2 on Python 3 returns memoryview objects, which we need to
# cast to bytes to decode # cast to bytes to decode

View file

@ -12,29 +12,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from . import engines from . import engines
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.storage.database import DatabasePool, LoggingTransaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BackgroundUpdatePerformance: class BackgroundUpdatePerformance:
"""Tracks the how long a background update is taking to update its items""" """Tracks the how long a background update is taking to update its items"""
def __init__(self, name): def __init__(self, name: str):
self.name = name self.name = name
self.total_item_count = 0 self.total_item_count = 0
self.total_duration_ms = 0 self.total_duration_ms = 0.0
self.avg_item_count = 0 self.avg_item_count = 0.0
self.avg_duration_ms = 0 self.avg_duration_ms = 0.0
def update(self, item_count, duration_ms): def update(self, item_count: int, duration_ms: float) -> None:
"""Update the stats after doing an update""" """Update the stats after doing an update"""
self.total_item_count += item_count self.total_item_count += item_count
self.total_duration_ms += duration_ms self.total_duration_ms += duration_ms
@ -44,7 +49,7 @@ class BackgroundUpdatePerformance:
self.avg_item_count += 0.1 * (item_count - self.avg_item_count) self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms) self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
def average_items_per_ms(self): def average_items_per_ms(self) -> Optional[float]:
"""An estimate of how long it takes to do a single update. """An estimate of how long it takes to do a single update.
Returns: Returns:
A duration in ms as a float A duration in ms as a float
@ -58,7 +63,7 @@ class BackgroundUpdatePerformance:
# changes in how long the update process takes. # changes in how long the update process takes.
return float(self.avg_item_count) / float(self.avg_duration_ms) return float(self.avg_item_count) / float(self.avg_duration_ms)
def total_items_per_ms(self): def total_items_per_ms(self) -> Optional[float]:
"""An estimate of how long it takes to do a single update. """An estimate of how long it takes to do a single update.
Returns: Returns:
A duration in ms as a float A duration in ms as a float
@ -83,21 +88,25 @@ class BackgroundUpdater:
BACKGROUND_UPDATE_INTERVAL_MS = 1000 BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100 BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, hs, database): def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.db_pool = database self.db_pool = database
# if a background update is currently running, its name. # if a background update is currently running, its name.
self._current_background_update = None # type: Optional[str] self._current_background_update = None # type: Optional[str]
self._background_update_performance = {} self._background_update_performance = (
self._background_update_handlers = {} {}
) # type: Dict[str, BackgroundUpdatePerformance]
self._background_update_handlers = (
{}
) # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
self._all_done = False self._all_done = False
def start_doing_background_updates(self): def start_doing_background_updates(self) -> None:
run_as_background_process("background_updates", self.run_background_updates) run_as_background_process("background_updates", self.run_background_updates)
async def run_background_updates(self, sleep=True): async def run_background_updates(self, sleep: bool = True) -> None:
logger.info("Starting background schema updates") logger.info("Starting background schema updates")
while True: while True:
if sleep: if sleep:
@ -148,7 +157,7 @@ class BackgroundUpdater:
return False return False
async def has_completed_background_update(self, update_name) -> bool: async def has_completed_background_update(self, update_name: str) -> bool:
"""Check if the given background update has finished running. """Check if the given background update has finished running.
""" """
if self._all_done: if self._all_done:
@ -173,8 +182,7 @@ class BackgroundUpdater:
Returns once some amount of work is done. Returns once some amount of work is done.
Args: Args:
desired_duration_ms(float): How long we want to spend desired_duration_ms: How long we want to spend updating.
updating.
Returns: Returns:
True if we have finished running all the background updates, otherwise False True if we have finished running all the background updates, otherwise False
""" """
@ -220,6 +228,7 @@ class BackgroundUpdater:
return False return False
async def _do_background_update(self, desired_duration_ms: float) -> int: async def _do_background_update(self, desired_duration_ms: float) -> int:
assert self._current_background_update is not None
update_name = self._current_background_update update_name = self._current_background_update
logger.info("Starting update batch on background update '%s'", update_name) logger.info("Starting update batch on background update '%s'", update_name)
@ -273,7 +282,11 @@ class BackgroundUpdater:
return len(self._background_update_performance) return len(self._background_update_performance)
def register_background_update_handler(self, update_name, update_handler): def register_background_update_handler(
self,
update_name: str,
update_handler: Callable[[JsonDict, int], Awaitable[int]],
):
"""Register a handler for doing a background update. """Register a handler for doing a background update.
The handler should take two arguments: The handler should take two arguments:
@ -287,12 +300,12 @@ class BackgroundUpdater:
The handler is responsible for updating the progress of the update. The handler is responsible for updating the progress of the update.
Args: Args:
update_name(str): The name of the update that this code handles. update_name: The name of the update that this code handles.
update_handler(function): The function that does the update. update_handler: The function that does the update.
""" """
self._background_update_handlers[update_name] = update_handler self._background_update_handlers[update_name] = update_handler
def register_noop_background_update(self, update_name): def register_noop_background_update(self, update_name: str) -> None:
"""Register a noop handler for a background update. """Register a noop handler for a background update.
This is useful when we previously did a background update, but no This is useful when we previously did a background update, but no
@ -302,10 +315,10 @@ class BackgroundUpdater:
also be called to clear the update. also be called to clear the update.
Args: Args:
update_name (str): Name of update update_name: Name of update
""" """
async def noop_update(progress, batch_size): async def noop_update(progress: JsonDict, batch_size: int) -> int:
await self._end_background_update(update_name) await self._end_background_update(update_name)
return 1 return 1
@ -313,14 +326,14 @@ class BackgroundUpdater:
def register_background_index_update( def register_background_index_update(
self, self,
update_name, update_name: str,
index_name, index_name: str,
table, table: str,
columns, columns: Iterable[str],
where_clause=None, where_clause: Optional[str] = None,
unique=False, unique: bool = False,
psql_only=False, psql_only: bool = False,
): ) -> None:
"""Helper for store classes to do a background index addition """Helper for store classes to do a background index addition
To use: To use:
@ -332,19 +345,19 @@ class BackgroundUpdater:
2. In the Store constructor, call this method 2. In the Store constructor, call this method
Args: Args:
update_name (str): update_name to register for update_name: update_name to register for
index_name (str): name of index to add index_name: name of index to add
table (str): table to add index to table: table to add index to
columns (list[str]): columns/expressions to include in index columns: columns/expressions to include in index
unique (bool): true to make a UNIQUE index unique: true to make a UNIQUE index
psql_only: true to only create this index on psql databases (useful psql_only: true to only create this index on psql databases (useful
for virtual sqlite tables) for virtual sqlite tables)
""" """
def create_index_psql(conn): def create_index_psql(conn: Connection) -> None:
conn.rollback() conn.rollback()
# postgres insists on autocommit for the index # postgres insists on autocommit for the index
conn.set_session(autocommit=True) conn.set_session(autocommit=True) # type: ignore
try: try:
c = conn.cursor() c = conn.cursor()
@ -371,9 +384,9 @@ class BackgroundUpdater:
logger.debug("[SQL] %s", sql) logger.debug("[SQL] %s", sql)
c.execute(sql) c.execute(sql)
finally: finally:
conn.set_session(autocommit=False) conn.set_session(autocommit=False) # type: ignore
def create_index_sqlite(conn): def create_index_sqlite(conn: Connection) -> None:
# Sqlite doesn't support concurrent creation of indexes. # Sqlite doesn't support concurrent creation of indexes.
# #
# We don't use partial indices on SQLite as it wasn't introduced # We don't use partial indices on SQLite as it wasn't introduced
@ -399,7 +412,7 @@ class BackgroundUpdater:
c.execute(sql) c.execute(sql)
if isinstance(self.db_pool.engine, engines.PostgresEngine): if isinstance(self.db_pool.engine, engines.PostgresEngine):
runner = create_index_psql runner = create_index_psql # type: Optional[Callable[[Connection], None]]
elif psql_only: elif psql_only:
runner = None runner = None
else: else:
@ -433,7 +446,9 @@ class BackgroundUpdater:
"background_updates", keyvalues={"update_name": update_name} "background_updates", keyvalues={"update_name": update_name}
) )
async def _background_update_progress(self, update_name: str, progress: dict): async def _background_update_progress(
self, update_name: str, progress: dict
) -> None:
"""Update the progress of a background update """Update the progress of a background update
Args: Args:
@ -441,20 +456,22 @@ class BackgroundUpdater:
progress: The progress of the update. progress: The progress of the update.
""" """
return await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"background_update_progress", "background_update_progress",
self._background_update_progress_txn, self._background_update_progress_txn,
update_name, update_name,
progress, progress,
) )
def _background_update_progress_txn(self, txn, update_name, progress): def _background_update_progress_txn(
self, txn: "LoggingTransaction", update_name: str, progress: JsonDict
) -> None:
"""Update the progress of a background update """Update the progress of a background update
Args: Args:
txn(cursor): The transaction. txn: The transaction.
update_name(str): The name of the background update task update_name: The name of the background update task
progress(dict): The progress of the update. progress: The progress of the update.
""" """
progress_json = json_encoder.encode(progress) progress_json = json_encoder.encode(progress)

View file

@ -17,11 +17,12 @@
import logging import logging
import attr import attr
from signedjson.types import VerifyKey
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class FetchKeyResult: class FetchKeyResult:
verify_key = attr.ib() # VerifyKey: the key itself verify_key = attr.ib(type=VerifyKey) # the key itself
valid_until_ts = attr.ib() # int: how long we can use this key for valid_until_ts = attr.ib(type=int) # how long we can use this key for

View file

@ -18,9 +18,10 @@ import logging
import os import os
import re import re
from collections import Counter from collections import Counter
from typing import Optional, TextIO from typing import Generator, Iterable, List, Optional, TextIO, Tuple
import attr import attr
from typing_extensions import Counter as CounterType
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.database import LoggingDatabaseConnection
@ -70,7 +71,7 @@ def prepare_database(
db_conn: LoggingDatabaseConnection, db_conn: LoggingDatabaseConnection,
database_engine: BaseDatabaseEngine, database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig], config: Optional[HomeServerConfig],
databases: Collection[str] = ["main", "state"], databases: Collection[str] = ("main", "state"),
): ):
"""Prepares a physical database for usage. Will either create all necessary tables """Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version. or upgrade from an older schema version.
@ -155,7 +156,9 @@ def prepare_database(
raise raise
def _setup_new_database(cur, database_engine, databases): def _setup_new_database(
cur: Cursor, database_engine: BaseDatabaseEngine, databases: Collection[str]
) -> None:
"""Sets up the physical database by finding a base set of "full schemas" and """Sets up the physical database by finding a base set of "full schemas" and
then applying any necessary deltas, including schemas from the given data then applying any necessary deltas, including schemas from the given data
stores. stores.
@ -188,10 +191,9 @@ def _setup_new_database(cur, database_engine, databases):
folder as well those in the data stores specified. folder as well those in the data stores specified.
Args: Args:
cur (Cursor): a database cursor cur: a database cursor
database_engine (DatabaseEngine) database_engine
databases (list[str]): The names of the databases to instantiate databases: The names of the databases to instantiate on the given physical database.
on the given physical database.
""" """
# We're about to set up a brand new database so we check that its # We're about to set up a brand new database so we check that its
@ -199,12 +201,11 @@ def _setup_new_database(cur, database_engine, databases):
database_engine.check_new_database(cur) database_engine.check_new_database(cur)
current_dir = os.path.join(dir_path, "schema", "full_schemas") current_dir = os.path.join(dir_path, "schema", "full_schemas")
directory_entries = os.listdir(current_dir)
# First we find the highest full schema version we have # First we find the highest full schema version we have
valid_versions = [] valid_versions = []
for filename in directory_entries: for filename in os.listdir(current_dir):
try: try:
ver = int(filename) ver = int(filename)
except ValueError: except ValueError:
@ -237,7 +238,7 @@ def _setup_new_database(cur, database_engine, databases):
for database in databases for database in databases
) )
directory_entries = [] directory_entries = [] # type: List[_DirectoryListing]
for directory in directories: for directory in directories:
directory_entries.extend( directory_entries.extend(
_DirectoryListing(file_name, os.path.join(directory, file_name)) _DirectoryListing(file_name, os.path.join(directory, file_name))
@ -275,15 +276,15 @@ def _setup_new_database(cur, database_engine, databases):
def _upgrade_existing_database( def _upgrade_existing_database(
cur, cur: Cursor,
current_version, current_version: int,
applied_delta_files, applied_delta_files: List[str],
upgraded, upgraded: bool,
database_engine, database_engine: BaseDatabaseEngine,
config, config: Optional[HomeServerConfig],
databases, databases: Collection[str],
is_empty=False, is_empty: bool = False,
): ) -> None:
"""Upgrades an existing physical database. """Upgrades an existing physical database.
Delta files can either be SQL stored in *.sql files, or python modules Delta files can either be SQL stored in *.sql files, or python modules
@ -323,21 +324,20 @@ def _upgrade_existing_database(
for a version before applying those in the next version. for a version before applying those in the next version.
Args: Args:
cur (Cursor) cur
current_version (int): The current version of the schema. current_version: The current version of the schema.
applied_delta_files (list): A list of deltas that have already been applied_delta_files: A list of deltas that have already been applied.
applied. upgraded: Whether the current version was generated by having
upgraded (bool): Whether the current version was generated by having
applied deltas or from full schema file. If `True` the function applied deltas or from full schema file. If `True` the function
will never apply delta files for the given `current_version`, since will never apply delta files for the given `current_version`, since
the current_version wasn't generated by applying those delta files. the current_version wasn't generated by applying those delta files.
database_engine (DatabaseEngine) database_engine
config (synapse.config.homeserver.HomeServerConfig|None): config:
None if we are initialising a blank database, otherwise the application None if we are initialising a blank database, otherwise the application
config config
databases (list[str]): The names of the databases to instantiate databases: The names of the databases to instantiate
on the given physical database. on the given physical database.
is_empty (bool): Is this a blank database? I.e. do we need to run the is_empty: Is this a blank database? I.e. do we need to run the
upgrade portions of the delta scripts. upgrade portions of the delta scripts.
""" """
if is_empty: if is_empty:
@ -358,6 +358,7 @@ def _upgrade_existing_database(
if not is_empty and "main" in databases: if not is_empty and "main" in databases:
from synapse.storage.databases.main import check_database_before_upgrade from synapse.storage.databases.main import check_database_before_upgrade
assert config is not None
check_database_before_upgrade(cur, database_engine, config) check_database_before_upgrade(cur, database_engine, config)
start_ver = current_version start_ver = current_version
@ -388,10 +389,10 @@ def _upgrade_existing_database(
) )
# Used to check if we have any duplicate file names # Used to check if we have any duplicate file names
file_name_counter = Counter() file_name_counter = Counter() # type: CounterType[str]
# Now find which directories have anything of interest. # Now find which directories have anything of interest.
directory_entries = [] directory_entries = [] # type: List[_DirectoryListing]
for directory in directories: for directory in directories:
logger.debug("Looking for schema deltas in %s", directory) logger.debug("Looking for schema deltas in %s", directory)
try: try:
@ -445,11 +446,11 @@ def _upgrade_existing_database(
module_name = "synapse.storage.v%d_%s" % (v, root_name) module_name = "synapse.storage.v%d_%s" % (v, root_name)
with open(absolute_path) as python_file: with open(absolute_path) as python_file:
module = imp.load_source(module_name, absolute_path, python_file) module = imp.load_source(module_name, absolute_path, python_file) # type: ignore
logger.info("Running script %s", relative_path) logger.info("Running script %s", relative_path)
module.run_create(cur, database_engine) module.run_create(cur, database_engine) # type: ignore
if not is_empty: if not is_empty:
module.run_upgrade(cur, database_engine, config=config) module.run_upgrade(cur, database_engine, config=config) # type: ignore
elif ext == ".pyc" or file_name == "__pycache__": elif ext == ".pyc" or file_name == "__pycache__":
# Sometimes .pyc files turn up anyway even though we've # Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package # disabled their generation; e.g. from distribution package
@ -497,14 +498,15 @@ def _upgrade_existing_database(
logger.info("Schema now up to date") logger.info("Schema now up to date")
def _apply_module_schemas(txn, database_engine, config): def _apply_module_schemas(
txn: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
) -> None:
"""Apply the module schemas for the dynamic modules, if any """Apply the module schemas for the dynamic modules, if any
Args: Args:
cur: database cursor cur: database cursor
database_engine: synapse database engine class database_engine:
config (synapse.config.homeserver.HomeServerConfig): config: application config
application config
""" """
for (mod, _config) in config.password_providers: for (mod, _config) in config.password_providers:
if not hasattr(mod, "get_db_schema_files"): if not hasattr(mod, "get_db_schema_files"):
@ -515,15 +517,19 @@ def _apply_module_schemas(txn, database_engine, config):
) )
def _apply_module_schema_files(cur, database_engine, modname, names_and_streams): def _apply_module_schema_files(
cur: Cursor,
database_engine: BaseDatabaseEngine,
modname: str,
names_and_streams: Iterable[Tuple[str, TextIO]],
) -> None:
"""Apply the module schemas for a single module """Apply the module schemas for a single module
Args: Args:
cur: database cursor cur: database cursor
database_engine: synapse database engine class database_engine: synapse database engine class
modname (str): fully qualified name of the module modname: fully qualified name of the module
names_and_streams (Iterable[(str, file)]): the names and streams of names_and_streams: the names and streams of schemas to be applied
schemas to be applied
""" """
cur.execute( cur.execute(
"SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,), "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
@ -549,7 +555,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
) )
def get_statements(f): def get_statements(f: Iterable[str]) -> Generator[str, None, None]:
statement_buffer = "" statement_buffer = ""
in_comment = False # If we're in a /* ... */ style comment in_comment = False # If we're in a /* ... */ style comment
@ -594,17 +600,19 @@ def get_statements(f):
statement_buffer = statements[-1].strip() statement_buffer = statements[-1].strip()
def executescript(txn, schema_path): def executescript(txn: Cursor, schema_path: str) -> None:
with open(schema_path, "r") as f: with open(schema_path, "r") as f:
execute_statements_from_stream(txn, f) execute_statements_from_stream(txn, f)
def execute_statements_from_stream(cur: Cursor, f: TextIO): def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None:
for statement in get_statements(f): for statement in get_statements(f):
cur.execute(statement) cur.execute(statement)
def _get_or_create_schema_state(txn, database_engine): def _get_or_create_schema_state(
txn: Cursor, database_engine: BaseDatabaseEngine
) -> Optional[Tuple[int, List[str], bool]]:
# Bluntly try creating the schema_version tables. # Bluntly try creating the schema_version tables.
schema_path = os.path.join(dir_path, "schema", "schema_version.sql") schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
executescript(txn, schema_path) executescript(txn, schema_path)
@ -612,7 +620,6 @@ def _get_or_create_schema_state(txn, database_engine):
txn.execute("SELECT version, upgraded FROM schema_version") txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone() row = txn.fetchone()
current_version = int(row[0]) if row else None current_version = int(row[0]) if row else None
upgraded = bool(row[1]) if row else None
if current_version: if current_version:
txn.execute( txn.execute(
@ -620,6 +627,7 @@ def _get_or_create_schema_state(txn, database_engine):
(current_version,), (current_version,),
) )
applied_deltas = [d for d, in txn] applied_deltas = [d for d, in txn]
upgraded = bool(row[1])
return current_version, applied_deltas, upgraded return current_version, applied_deltas, upgraded
return None return None
@ -634,5 +642,5 @@ class _DirectoryListing:
`file_name` attr is kept first. `file_name` attr is kept first.
""" """
file_name = attr.ib() file_name = attr.ib(type=str)
absolute_path = attr.ib() absolute_path = attr.ib(type=str)

View file

@ -15,7 +15,12 @@
import itertools import itertools
import logging import logging
from typing import Set from typing import TYPE_CHECKING, Set
from synapse.storage.databases import Databases
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -24,10 +29,10 @@ class PurgeEventsStorage:
"""High level interface for purging rooms and event history. """High level interface for purging rooms and event history.
""" """
def __init__(self, hs, stores): def __init__(self, hs: "HomeServer", stores: Databases):
self.stores = stores self.stores = stores
async def purge_room(self, room_id: str): async def purge_room(self, room_id: str) -> None:
"""Deletes all record of a room """Deletes all record of a room
""" """

View file

@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, List, Optional, Tuple
import attr import attr
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.types import JsonDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,18 +29,18 @@ class PaginationChunk:
"""Returned by relation pagination APIs. """Returned by relation pagination APIs.
Attributes: Attributes:
chunk (list): The rows returned by pagination chunk: The rows returned by pagination
next_batch (Any|None): Token to fetch next set of results with, if next_batch: Token to fetch next set of results with, if
None then there are no more results. None then there are no more results.
prev_batch (Any|None): Token to fetch previous set of results with, if prev_batch: Token to fetch previous set of results with, if
None then there are no previous results. None then there are no previous results.
""" """
chunk = attr.ib() chunk = attr.ib(type=List[JsonDict])
next_batch = attr.ib(default=None) next_batch = attr.ib(type=Optional[Any], default=None)
prev_batch = attr.ib(default=None) prev_batch = attr.ib(type=Optional[Any], default=None)
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
d = {"chunk": self.chunk} d = {"chunk": self.chunk}
if self.next_batch: if self.next_batch:
@ -59,25 +61,25 @@ class RelationPaginationToken:
boundaries of the chunk as pagination tokens. boundaries of the chunk as pagination tokens.
Attributes: Attributes:
topological (int): The topological ordering of the boundary event topological: The topological ordering of the boundary event
stream (int): The stream ordering of the boundary event. stream: The stream ordering of the boundary event.
""" """
topological = attr.ib() topological = attr.ib(type=int)
stream = attr.ib() stream = attr.ib(type=int)
@staticmethod @staticmethod
def from_string(string): def from_string(string: str) -> "RelationPaginationToken":
try: try:
t, s = string.split("-") t, s = string.split("-")
return RelationPaginationToken(int(t), int(s)) return RelationPaginationToken(int(t), int(s))
except ValueError: except ValueError:
raise SynapseError(400, "Invalid token") raise SynapseError(400, "Invalid token")
def to_string(self): def to_string(self) -> str:
return "%d-%d" % (self.topological, self.stream) return "%d-%d" % (self.topological, self.stream)
def as_tuple(self): def as_tuple(self) -> Tuple[Any, ...]:
return attr.astuple(self) return attr.astuple(self)
@ -89,23 +91,23 @@ class AggregationPaginationToken:
aggregation groups, we can just use them as our pagination token. aggregation groups, we can just use them as our pagination token.
Attributes: Attributes:
count (int): The count of relations in the boundar group. count: The count of relations in the boundary group.
stream (int): The MAX stream ordering in the boundary group. stream: The MAX stream ordering in the boundary group.
""" """
count = attr.ib() count = attr.ib(type=int)
stream = attr.ib() stream = attr.ib(type=int)
@staticmethod @staticmethod
def from_string(string): def from_string(string: str) -> "AggregationPaginationToken":
try: try:
c, s = string.split("-") c, s = string.split("-")
return AggregationPaginationToken(int(c), int(s)) return AggregationPaginationToken(int(c), int(s))
except ValueError: except ValueError:
raise SynapseError(400, "Invalid token") raise SynapseError(400, "Invalid token")
def to_string(self): def to_string(self) -> str:
return "%d-%d" % (self.count, self.stream) return "%d-%d" % (self.count, self.stream)
def as_tuple(self): def as_tuple(self) -> Tuple[Any, ...]:
return attr.astuple(self) return attr.astuple(self)

View file

@ -12,9 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar from typing import (
TYPE_CHECKING,
Awaitable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
)
import attr import attr
@ -22,6 +31,10 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap from synapse.types import MutableStateMap, StateMap
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.storage.databases import Databases
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Used for generic functions below # Used for generic functions below
@ -330,10 +343,12 @@ class StateGroupStorage:
"""High level interface to fetching state for event. """High level interface to fetching state for event.
""" """
def __init__(self, hs, stores): def __init__(self, hs: "HomeServer", stores: "Databases"):
self.stores = stores self.stores = stores
async def get_state_group_delta(self, state_group: int): async def get_state_group_delta(
self, state_group: int
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
"""Given a state group try to return a previous group and a delta between """Given a state group try to return a previous group and a delta between
the old and the new. the old and the new.
@ -341,8 +356,8 @@ class StateGroupStorage:
state_group: The state group used to retrieve state deltas. state_group: The state group used to retrieve state deltas.
Returns: Returns:
Tuple[Optional[int], Optional[StateMap[str]]]: A tuple of the previous group and a state map of the event IDs which
(prev_group, delta_ids) make up the delta between the old and new state groups.
""" """
return await self.stores.state.get_state_group_delta(state_group) return await self.stores.state.get_state_group_delta(state_group)
@ -436,7 +451,7 @@ class StateGroupStorage:
async def get_state_for_events( async def get_state_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
): ) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state """Given a list of event_ids and type tuples, return a list of state
dicts for each event. dicts for each event.
@ -472,7 +487,7 @@ class StateGroupStorage:
async def get_state_ids_for_events( async def get_state_ids_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
): ) -> Dict[str, StateMap[str]]:
""" """
Get the state dicts corresponding to a list of events, containing the event_ids Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves) of the state events (as opposed to the events themselves)
@ -500,7 +515,7 @@ class StateGroupStorage:
async def get_state_for_event( async def get_state_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all() self, event_id: str, state_filter: StateFilter = StateFilter.all()
): ) -> StateMap[EventBase]:
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event
@ -516,7 +531,7 @@ class StateGroupStorage:
async def get_state_ids_for_event( async def get_state_ids_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all() self, event_id: str, state_filter: StateFilter = StateFilter.all()
): ) -> StateMap[str]:
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event