0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-20 18:03:59 +01:00

Add some type annotations in synapse.storage (#6987)

I cracked, and added some type definitions in synapse.storage.
This commit is contained in:
Richard van der Hoff 2020-02-27 11:53:40 +00:00 committed by GitHub
parent 3e99528f2b
commit 132b673dbe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 270 additions and 84 deletions

1
changelog.d/6987.misc Normal file
View file

@ -0,0 +1 @@
Add some type annotations to the database storage classes.

View file

@ -15,9 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
import time
from typing import Iterable, Tuple
from time import monotonic as monotonic_time
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from six import iteritems, iterkeys, itervalues
from six.moves import intern, range
@ -32,24 +32,14 @@ from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.util.stringutils import exception_to_unicode
# import a function which will return a monotonic time, in seconds
try:
# on python 3, use time.monotonic, since time.clock can go backwards
from time import monotonic as monotonic_time
except ImportError:
# ... but python 2 doesn't have it
from time import clock as monotonic_time
logger = logging.getLogger(__name__)
try:
MAX_TXN_ID = sys.maxint - 1
except AttributeError:
# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1
# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
@ -77,7 +67,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
def make_pool(
reactor, db_config: DatabaseConnectionConfig, engine
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> adbapi.ConnectionPool:
"""Get the connection pool for the database.
"""
@ -90,7 +80,9 @@ def make_pool(
)
def make_conn(db_config: DatabaseConnectionConfig, engine):
def make_conn(
db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> Connection:
"""Make a new connection to the database and return it.
Returns:
@ -107,20 +99,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
return db_conn
class LoggingTransaction(object):
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
#
# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
# that mypy sees the type but the runtime python doesn't.
_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
class LoggingTransaction:
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method.
Args:
txn: The database transcation object to wrap.
name (str): The name of this transactions for logging.
database_engine (Sqlite3Engine|PostgresEngine)
after_callbacks(list|None): A list that callbacks will be appended to
name: The name of this transactions for logging.
database_engine
after_callbacks: A list that callbacks will be appended to
that have been added by `call_after` which should be run on
successful completion of the transaction. None indicates that no
callbacks should be allowed to be scheduled to run.
exception_callbacks(list|None): A list that callbacks will be appended
exception_callbacks: A list that callbacks will be appended
to that have been added by `call_on_exception` which should be run
if transaction ends with an error. None indicates that no callbacks
should be allowed to be scheduled to run.
@ -135,46 +134,67 @@ class LoggingTransaction(object):
]
def __init__(
self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
self,
txn: Cursor,
name: str,
database_engine: BaseDatabaseEngine,
after_callbacks: Optional[List[_CallbackListEntry]] = None,
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks)
object.__setattr__(self, "exception_callbacks", exception_callbacks)
self.txn = txn
self.name = name
self.database_engine = database_engine
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks
def call_after(self, callback, *args, **kwargs):
def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
"""
# if self.after_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs))
def call_on_exception(self, callback, *args, **kwargs):
def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))
def __getattr__(self, name):
return getattr(self.txn, name)
def fetchall(self) -> List[Tuple]:
return self.txn.fetchall()
def __setattr__(self, name, value):
setattr(self.txn, name, value)
def fetchone(self) -> Tuple:
return self.txn.fetchone()
def __iter__(self):
def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__()
@property
def rowcount(self) -> int:
return self.txn.rowcount
@property
def description(self) -> Any:
return self.txn.description
def execute_batch(self, sql, args):
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch
from psycopg2.extras import execute_batch # type: ignore
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
for val in args:
self.execute(sql, val)
def execute(self, sql, *args):
def execute(self, sql: str, *args: Any):
self._do_execute(self.txn.execute, sql, *args)
def executemany(self, sql, *args):
def executemany(self, sql: str, *args: Any):
self._do_execute(self.txn.executemany, sql, *args)
def _make_sql_one_line(self, sql):
@ -207,6 +227,9 @@ class LoggingTransaction(object):
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs)
def close(self):
self.txn.close()
class PerformanceCounters(object):
def __init__(self):
@ -251,7 +274,9 @@ class Database(object):
_TXN_ID = 0
def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
def __init__(
self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
):
self.hs = hs
self._clock = hs.get_clock()
self._database_config = database_config
@ -259,9 +284,9 @@ class Database(object):
self.updates = BackgroundUpdater(hs, self)
self._previous_txn_total_time = 0
self._current_txn_total_time = 0
self._previous_loop_ts = 0
self._previous_txn_total_time = 0.0
self._current_txn_total_time = 0.0
self._previous_loop_ts = 0.0
# TODO(paul): These can eventually be removed once the metrics code
# is running in mainline, and we have some nice monitoring frontends
@ -463,23 +488,23 @@ class Database(object):
sql_txn_timer.labels(desc).observe(duration)
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
"""Starts a transaction on the database and runs a given function
Arguments:
desc (str): description of the transaction, for logging and metrics
func (func): callback function, which will be called with a
desc: description of the transaction, for logging and metrics
func: callback function, which will be called with a
database transaction (twisted.enterprise.adbapi.Transaction) as
its first argument, followed by `args` and `kwargs`.
args (list): positional args to pass to `func`
kwargs (dict): named args to pass to `func`
args: positional args to pass to `func`
kwargs: named args to pass to `func`
Returns:
Deferred: The result of func
"""
after_callbacks = []
exception_callbacks = []
after_callbacks = [] # type: List[_CallbackListEntry]
exception_callbacks = [] # type: List[_CallbackListEntry]
if LoggingContext.current_context() == LoggingContext.sentinel:
logger.warning("Starting db txn '%s' from sentinel context", desc)
@ -505,15 +530,15 @@ class Database(object):
return result
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
"""Wraps the .runWithConnection() method on the underlying db_pool.
Arguments:
func (func): callback function, which will be called with a
func: callback function, which will be called with a
database connection (twisted.enterprise.adbapi.Connection) as
its first argument, followed by `args` and `kwargs`.
args (list): positional args to pass to `func`
kwargs (dict): named args to pass to `func`
args: positional args to pass to `func`
kwargs: named args to pass to `func`
Returns:
Deferred: The result of func
@ -800,7 +825,7 @@ class Database(object):
return False
# We didn't find any existing rows, so insert a new one
allvalues = {}
allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
@ -829,7 +854,7 @@ class Database(object):
Returns:
None
"""
allvalues = {}
allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
allvalues.update(insertion_values)
@ -916,7 +941,7 @@ class Database(object):
Returns:
None
"""
allnames = []
allnames = [] # type: List[str]
allnames.extend(key_names)
allnames.extend(value_names)
@ -1100,7 +1125,7 @@ class Database(object):
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
results = []
results = [] # type: List[Dict[str, Any]]
if not iterable:
return results
@ -1439,7 +1464,7 @@ class Database(object):
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
where_clause = "WHERE " if filters or keyvalues else ""
arg_list = []
arg_list = [] # type: List[Any]
if filters:
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
arg_list += list(filters.values())

View file

@ -12,29 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import platform
from ._base import IncorrectDatabaseSetup
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine
SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
def create_engine(database_config):
def create_engine(database_config) -> BaseDatabaseEngine:
name = database_config["name"]
engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class:
if name == "sqlite3":
import sqlite3
return Sqlite3Engine(sqlite3, database_config)
if name == "psycopg2":
# pypy requires psycopg2cffi rather than psycopg2
if name == "psycopg2" and platform.python_implementation() == "PyPy":
name = "psycopg2cffi"
module = importlib.import_module(name)
return engine_class(module, database_config)
if platform.python_implementation() == "PyPy":
import psycopg2cffi as psycopg2 # type: ignore
else:
import psycopg2 # type: ignore
return PostgresEngine(psycopg2, database_config)
raise RuntimeError("Unsupported database engine '%s'" % (name,))
__all__ = ["create_engine", "IncorrectDatabaseSetup"]
__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]

View file

@ -12,7 +12,94 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Generic, TypeVar
from synapse.storage.types import Connection
class IncorrectDatabaseSetup(RuntimeError):
pass
ConnectionType = TypeVar("ConnectionType", bound=Connection)
class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
def __init__(self, module, database_config: dict):
self.module = module
@property
@abc.abstractmethod
def single_threaded(self) -> bool:
...
@property
@abc.abstractmethod
def can_native_upsert(self) -> bool:
"""
Do we support native UPSERTs?
"""
...
@property
@abc.abstractmethod
def supports_tuple_comparison(self) -> bool:
"""
Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
"""
...
@property
@abc.abstractmethod
def supports_using_any_list(self) -> bool:
"""
Do we support using `a = ANY(?)` and passing a list
"""
...
@abc.abstractmethod
def check_database(
self, db_conn: ConnectionType, allow_outdated_version: bool = False
) -> None:
...
@abc.abstractmethod
def check_new_database(self, txn) -> None:
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
...
@abc.abstractmethod
def convert_param_style(self, sql: str) -> str:
...
@abc.abstractmethod
def on_new_connection(self, db_conn: ConnectionType) -> None:
...
@abc.abstractmethod
def is_deadlock(self, error: Exception) -> bool:
...
@abc.abstractmethod
def is_connection_closed(self, conn: ConnectionType) -> bool:
...
@abc.abstractmethod
def lock_table(self, txn, table: str) -> None:
...
@abc.abstractmethod
def get_next_state_group_id(self, txn) -> int:
"""Returns an int that can be used as a new state_group ID
"""
...
@property
@abc.abstractmethod
def server_version(self) -> str:
"""Gets a string giving the server version. For example: '3.22.0'
"""
...

View file

@ -15,16 +15,14 @@
import logging
from ._base import IncorrectDatabaseSetup
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
logger = logging.getLogger(__name__)
class PostgresEngine(object):
single_threaded = False
class PostgresEngine(BaseDatabaseEngine):
def __init__(self, database_module, database_config):
self.module = database_module
super().__init__(database_module, database_config)
self.module.extensions.register_type(self.module.extensions.UNICODE)
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
@ -36,6 +34,10 @@ class PostgresEngine(object):
self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None # unknown as yet
@property
def single_threaded(self) -> bool:
return False
def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and

View file

@ -12,16 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sqlite3
import struct
import threading
from synapse.storage.engines import BaseDatabaseEngine
class Sqlite3Engine(object):
single_threaded = True
class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
def __init__(self, database_module, database_config):
self.module = database_module
super().__init__(database_module, database_config)
database = database_config.get("args", {}).get("database")
self._is_in_memory = database in (None, ":memory:",)
@ -31,6 +31,10 @@ class Sqlite3Engine(object):
self._current_state_group_id = None
self._current_state_group_id_lock = threading.Lock()
@property
def single_threaded(self) -> bool:
return True
@property
def can_native_upsert(self):
"""
@ -68,7 +72,6 @@ class Sqlite3Engine(object):
return sql
def on_new_connection(self, db_conn):
# We need to import here to avoid an import loop.
from synapse.storage.prepare_database import prepare_database

65
synapse/storage/types.py Normal file
View file

@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable, Iterator, List, Tuple
from typing_extensions import Protocol
"""
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""
class Cursor(Protocol):
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
...
def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
...
def fetchall(self) -> List[Tuple]:
...
def fetchone(self) -> Tuple:
...
@property
def description(self) -> Any:
return None
@property
def rowcount(self) -> int:
return 0
def __iter__(self) -> Iterator[Tuple]:
...
def close(self) -> None:
...
class Connection(Protocol):
def cursor(self) -> Cursor:
...
def close(self) -> None:
...
def commit(self) -> None:
...
def rollback(self, *args, **kwargs) -> None:
...

View file

@ -168,7 +168,6 @@ commands=
coverage html
[testenv:mypy]
basepython = python3.7
skip_install = True
deps =
{[base]deps}
@ -179,7 +178,8 @@ env =
extras = all
commands = mypy \
synapse/api \
synapse/config/ \
synapse/appservice \
synapse/config \
synapse/events/spamcheck.py \
synapse/federation/sender \
synapse/federation/transport \
@ -192,6 +192,7 @@ commands = mypy \
synapse/rest \
synapse/spam_checker_api \
synapse/storage/engines \
synapse/storage/database.py \
synapse/streams
# To find all folders that pass mypy you run: