forked from MirrorHub/synapse
Add some type annotations in synapse.storage
(#6987)
I cracked, and added some type definitions in synapse.storage.
This commit is contained in:
parent
3e99528f2b
commit
132b673dbe
8 changed files with 270 additions and 84 deletions
1
changelog.d/6987.misc
Normal file
1
changelog.d/6987.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add some type annotations to the database storage classes.
|
|
@ -15,9 +15,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
from typing import Iterable, Tuple
|
from time import monotonic as monotonic_time
|
||||||
|
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
|
||||||
|
|
||||||
from six import iteritems, iterkeys, itervalues
|
from six import iteritems, iterkeys, itervalues
|
||||||
from six.moves import intern, range
|
from six.moves import intern, range
|
||||||
|
@ -32,24 +32,14 @@ from synapse.config.database import DatabaseConnectionConfig
|
||||||
from synapse.logging.context import LoggingContext, make_deferred_yieldable
|
from synapse.logging.context import LoggingContext, make_deferred_yieldable
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage.background_updates import BackgroundUpdater
|
from synapse.storage.background_updates import BackgroundUpdater
|
||||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||||
|
from synapse.storage.types import Connection, Cursor
|
||||||
from synapse.util.stringutils import exception_to_unicode
|
from synapse.util.stringutils import exception_to_unicode
|
||||||
|
|
||||||
# import a function which will return a monotonic time, in seconds
|
|
||||||
try:
|
|
||||||
# on python 3, use time.monotonic, since time.clock can go backwards
|
|
||||||
from time import monotonic as monotonic_time
|
|
||||||
except ImportError:
|
|
||||||
# ... but python 2 doesn't have it
|
|
||||||
from time import clock as monotonic_time
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
try:
|
# python 3 does not have a maximum int value
|
||||||
MAX_TXN_ID = sys.maxint - 1
|
MAX_TXN_ID = 2 ** 63 - 1
|
||||||
except AttributeError:
|
|
||||||
# python 3 does not have a maximum int value
|
|
||||||
MAX_TXN_ID = 2 ** 63 - 1
|
|
||||||
|
|
||||||
sql_logger = logging.getLogger("synapse.storage.SQL")
|
sql_logger = logging.getLogger("synapse.storage.SQL")
|
||||||
transaction_logger = logging.getLogger("synapse.storage.txn")
|
transaction_logger = logging.getLogger("synapse.storage.txn")
|
||||||
|
@ -77,7 +67,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
|
||||||
|
|
||||||
|
|
||||||
def make_pool(
|
def make_pool(
|
||||||
reactor, db_config: DatabaseConnectionConfig, engine
|
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
|
||||||
) -> adbapi.ConnectionPool:
|
) -> adbapi.ConnectionPool:
|
||||||
"""Get the connection pool for the database.
|
"""Get the connection pool for the database.
|
||||||
"""
|
"""
|
||||||
|
@ -90,7 +80,9 @@ def make_pool(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def make_conn(db_config: DatabaseConnectionConfig, engine):
|
def make_conn(
|
||||||
|
db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
|
||||||
|
) -> Connection:
|
||||||
"""Make a new connection to the database and return it.
|
"""Make a new connection to the database and return it.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -107,20 +99,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
|
||||||
return db_conn
|
return db_conn
|
||||||
|
|
||||||
|
|
||||||
class LoggingTransaction(object):
|
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
|
||||||
|
#
|
||||||
|
# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
|
||||||
|
# that mypy sees the type but the runtime python doesn't.
|
||||||
|
_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingTransaction:
|
||||||
"""An object that almost-transparently proxies for the 'txn' object
|
"""An object that almost-transparently proxies for the 'txn' object
|
||||||
passed to the constructor. Adds logging and metrics to the .execute()
|
passed to the constructor. Adds logging and metrics to the .execute()
|
||||||
method.
|
method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn: The database transcation object to wrap.
|
txn: The database transcation object to wrap.
|
||||||
name (str): The name of this transactions for logging.
|
name: The name of this transactions for logging.
|
||||||
database_engine (Sqlite3Engine|PostgresEngine)
|
database_engine
|
||||||
after_callbacks(list|None): A list that callbacks will be appended to
|
after_callbacks: A list that callbacks will be appended to
|
||||||
that have been added by `call_after` which should be run on
|
that have been added by `call_after` which should be run on
|
||||||
successful completion of the transaction. None indicates that no
|
successful completion of the transaction. None indicates that no
|
||||||
callbacks should be allowed to be scheduled to run.
|
callbacks should be allowed to be scheduled to run.
|
||||||
exception_callbacks(list|None): A list that callbacks will be appended
|
exception_callbacks: A list that callbacks will be appended
|
||||||
to that have been added by `call_on_exception` which should be run
|
to that have been added by `call_on_exception` which should be run
|
||||||
if transaction ends with an error. None indicates that no callbacks
|
if transaction ends with an error. None indicates that no callbacks
|
||||||
should be allowed to be scheduled to run.
|
should be allowed to be scheduled to run.
|
||||||
|
@ -135,46 +134,67 @@ class LoggingTransaction(object):
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
|
self,
|
||||||
|
txn: Cursor,
|
||||||
|
name: str,
|
||||||
|
database_engine: BaseDatabaseEngine,
|
||||||
|
after_callbacks: Optional[List[_CallbackListEntry]] = None,
|
||||||
|
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
|
||||||
):
|
):
|
||||||
object.__setattr__(self, "txn", txn)
|
self.txn = txn
|
||||||
object.__setattr__(self, "name", name)
|
self.name = name
|
||||||
object.__setattr__(self, "database_engine", database_engine)
|
self.database_engine = database_engine
|
||||||
object.__setattr__(self, "after_callbacks", after_callbacks)
|
self.after_callbacks = after_callbacks
|
||||||
object.__setattr__(self, "exception_callbacks", exception_callbacks)
|
self.exception_callbacks = exception_callbacks
|
||||||
|
|
||||||
def call_after(self, callback, *args, **kwargs):
|
def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
|
||||||
"""Call the given callback on the main twisted thread after the
|
"""Call the given callback on the main twisted thread after the
|
||||||
transaction has finished. Used to invalidate the caches on the
|
transaction has finished. Used to invalidate the caches on the
|
||||||
correct thread.
|
correct thread.
|
||||||
"""
|
"""
|
||||||
|
# if self.after_callbacks is None, that means that whatever constructed the
|
||||||
|
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
||||||
|
# is not the case.
|
||||||
|
assert self.after_callbacks is not None
|
||||||
self.after_callbacks.append((callback, args, kwargs))
|
self.after_callbacks.append((callback, args, kwargs))
|
||||||
|
|
||||||
def call_on_exception(self, callback, *args, **kwargs):
|
def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
|
||||||
|
# if self.exception_callbacks is None, that means that whatever constructed the
|
||||||
|
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
||||||
|
# is not the case.
|
||||||
|
assert self.exception_callbacks is not None
|
||||||
self.exception_callbacks.append((callback, args, kwargs))
|
self.exception_callbacks.append((callback, args, kwargs))
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def fetchall(self) -> List[Tuple]:
|
||||||
return getattr(self.txn, name)
|
return self.txn.fetchall()
|
||||||
|
|
||||||
def __setattr__(self, name, value):
|
def fetchone(self) -> Tuple:
|
||||||
setattr(self.txn, name, value)
|
return self.txn.fetchone()
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self) -> Iterator[Tuple]:
|
||||||
return self.txn.__iter__()
|
return self.txn.__iter__()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rowcount(self) -> int:
|
||||||
|
return self.txn.rowcount
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> Any:
|
||||||
|
return self.txn.description
|
||||||
|
|
||||||
def execute_batch(self, sql, args):
|
def execute_batch(self, sql, args):
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
from psycopg2.extras import execute_batch
|
from psycopg2.extras import execute_batch # type: ignore
|
||||||
|
|
||||||
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
|
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
|
||||||
else:
|
else:
|
||||||
for val in args:
|
for val in args:
|
||||||
self.execute(sql, val)
|
self.execute(sql, val)
|
||||||
|
|
||||||
def execute(self, sql, *args):
|
def execute(self, sql: str, *args: Any):
|
||||||
self._do_execute(self.txn.execute, sql, *args)
|
self._do_execute(self.txn.execute, sql, *args)
|
||||||
|
|
||||||
def executemany(self, sql, *args):
|
def executemany(self, sql: str, *args: Any):
|
||||||
self._do_execute(self.txn.executemany, sql, *args)
|
self._do_execute(self.txn.executemany, sql, *args)
|
||||||
|
|
||||||
def _make_sql_one_line(self, sql):
|
def _make_sql_one_line(self, sql):
|
||||||
|
@ -207,6 +227,9 @@ class LoggingTransaction(object):
|
||||||
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
|
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
|
||||||
sql_query_timer.labels(sql.split()[0]).observe(secs)
|
sql_query_timer.labels(sql.split()[0]).observe(secs)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.txn.close()
|
||||||
|
|
||||||
|
|
||||||
class PerformanceCounters(object):
|
class PerformanceCounters(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -251,7 +274,9 @@ class Database(object):
|
||||||
|
|
||||||
_TXN_ID = 0
|
_TXN_ID = 0
|
||||||
|
|
||||||
def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
|
def __init__(
|
||||||
|
self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
|
||||||
|
):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self._database_config = database_config
|
self._database_config = database_config
|
||||||
|
@ -259,9 +284,9 @@ class Database(object):
|
||||||
|
|
||||||
self.updates = BackgroundUpdater(hs, self)
|
self.updates = BackgroundUpdater(hs, self)
|
||||||
|
|
||||||
self._previous_txn_total_time = 0
|
self._previous_txn_total_time = 0.0
|
||||||
self._current_txn_total_time = 0
|
self._current_txn_total_time = 0.0
|
||||||
self._previous_loop_ts = 0
|
self._previous_loop_ts = 0.0
|
||||||
|
|
||||||
# TODO(paul): These can eventually be removed once the metrics code
|
# TODO(paul): These can eventually be removed once the metrics code
|
||||||
# is running in mainline, and we have some nice monitoring frontends
|
# is running in mainline, and we have some nice monitoring frontends
|
||||||
|
@ -463,23 +488,23 @@ class Database(object):
|
||||||
sql_txn_timer.labels(desc).observe(duration)
|
sql_txn_timer.labels(desc).observe(duration)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def runInteraction(self, desc, func, *args, **kwargs):
|
def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
|
||||||
"""Starts a transaction on the database and runs a given function
|
"""Starts a transaction on the database and runs a given function
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
desc (str): description of the transaction, for logging and metrics
|
desc: description of the transaction, for logging and metrics
|
||||||
func (func): callback function, which will be called with a
|
func: callback function, which will be called with a
|
||||||
database transaction (twisted.enterprise.adbapi.Transaction) as
|
database transaction (twisted.enterprise.adbapi.Transaction) as
|
||||||
its first argument, followed by `args` and `kwargs`.
|
its first argument, followed by `args` and `kwargs`.
|
||||||
|
|
||||||
args (list): positional args to pass to `func`
|
args: positional args to pass to `func`
|
||||||
kwargs (dict): named args to pass to `func`
|
kwargs: named args to pass to `func`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: The result of func
|
Deferred: The result of func
|
||||||
"""
|
"""
|
||||||
after_callbacks = []
|
after_callbacks = [] # type: List[_CallbackListEntry]
|
||||||
exception_callbacks = []
|
exception_callbacks = [] # type: List[_CallbackListEntry]
|
||||||
|
|
||||||
if LoggingContext.current_context() == LoggingContext.sentinel:
|
if LoggingContext.current_context() == LoggingContext.sentinel:
|
||||||
logger.warning("Starting db txn '%s' from sentinel context", desc)
|
logger.warning("Starting db txn '%s' from sentinel context", desc)
|
||||||
|
@ -505,15 +530,15 @@ class Database(object):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def runWithConnection(self, func, *args, **kwargs):
|
def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
|
||||||
"""Wraps the .runWithConnection() method on the underlying db_pool.
|
"""Wraps the .runWithConnection() method on the underlying db_pool.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
func (func): callback function, which will be called with a
|
func: callback function, which will be called with a
|
||||||
database connection (twisted.enterprise.adbapi.Connection) as
|
database connection (twisted.enterprise.adbapi.Connection) as
|
||||||
its first argument, followed by `args` and `kwargs`.
|
its first argument, followed by `args` and `kwargs`.
|
||||||
args (list): positional args to pass to `func`
|
args: positional args to pass to `func`
|
||||||
kwargs (dict): named args to pass to `func`
|
kwargs: named args to pass to `func`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: The result of func
|
Deferred: The result of func
|
||||||
|
@ -800,7 +825,7 @@ class Database(object):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# We didn't find any existing rows, so insert a new one
|
# We didn't find any existing rows, so insert a new one
|
||||||
allvalues = {}
|
allvalues = {} # type: Dict[str, Any]
|
||||||
allvalues.update(keyvalues)
|
allvalues.update(keyvalues)
|
||||||
allvalues.update(values)
|
allvalues.update(values)
|
||||||
allvalues.update(insertion_values)
|
allvalues.update(insertion_values)
|
||||||
|
@ -829,7 +854,7 @@ class Database(object):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
allvalues = {}
|
allvalues = {} # type: Dict[str, Any]
|
||||||
allvalues.update(keyvalues)
|
allvalues.update(keyvalues)
|
||||||
allvalues.update(insertion_values)
|
allvalues.update(insertion_values)
|
||||||
|
|
||||||
|
@ -916,7 +941,7 @@ class Database(object):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
allnames = []
|
allnames = [] # type: List[str]
|
||||||
allnames.extend(key_names)
|
allnames.extend(key_names)
|
||||||
allnames.extend(value_names)
|
allnames.extend(value_names)
|
||||||
|
|
||||||
|
@ -1100,7 +1125,7 @@ class Database(object):
|
||||||
keyvalues : dict of column names and values to select the rows with
|
keyvalues : dict of column names and values to select the rows with
|
||||||
retcols : list of strings giving the names of the columns to return
|
retcols : list of strings giving the names of the columns to return
|
||||||
"""
|
"""
|
||||||
results = []
|
results = [] # type: List[Dict[str, Any]]
|
||||||
|
|
||||||
if not iterable:
|
if not iterable:
|
||||||
return results
|
return results
|
||||||
|
@ -1439,7 +1464,7 @@ class Database(object):
|
||||||
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
|
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
|
||||||
|
|
||||||
where_clause = "WHERE " if filters or keyvalues else ""
|
where_clause = "WHERE " if filters or keyvalues else ""
|
||||||
arg_list = []
|
arg_list = [] # type: List[Any]
|
||||||
if filters:
|
if filters:
|
||||||
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
|
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
|
||||||
arg_list += list(filters.values())
|
arg_list += list(filters.values())
|
||||||
|
|
|
@ -12,29 +12,31 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import importlib
|
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
from ._base import IncorrectDatabaseSetup
|
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
|
||||||
from .postgres import PostgresEngine
|
from .postgres import PostgresEngine
|
||||||
from .sqlite import Sqlite3Engine
|
from .sqlite import Sqlite3Engine
|
||||||
|
|
||||||
SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
|
|
||||||
|
|
||||||
|
def create_engine(database_config) -> BaseDatabaseEngine:
|
||||||
def create_engine(database_config):
|
|
||||||
name = database_config["name"]
|
name = database_config["name"]
|
||||||
engine_class = SUPPORTED_MODULE.get(name, None)
|
|
||||||
|
|
||||||
if engine_class:
|
if name == "sqlite3":
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
return Sqlite3Engine(sqlite3, database_config)
|
||||||
|
|
||||||
|
if name == "psycopg2":
|
||||||
# pypy requires psycopg2cffi rather than psycopg2
|
# pypy requires psycopg2cffi rather than psycopg2
|
||||||
if name == "psycopg2" and platform.python_implementation() == "PyPy":
|
if platform.python_implementation() == "PyPy":
|
||||||
name = "psycopg2cffi"
|
import psycopg2cffi as psycopg2 # type: ignore
|
||||||
module = importlib.import_module(name)
|
else:
|
||||||
return engine_class(module, database_config)
|
import psycopg2 # type: ignore
|
||||||
|
|
||||||
|
return PostgresEngine(psycopg2, database_config)
|
||||||
|
|
||||||
raise RuntimeError("Unsupported database engine '%s'" % (name,))
|
raise RuntimeError("Unsupported database engine '%s'" % (name,))
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["create_engine", "IncorrectDatabaseSetup"]
|
__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]
|
||||||
|
|
|
@ -12,7 +12,94 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import abc
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
from synapse.storage.types import Connection
|
||||||
|
|
||||||
|
|
||||||
class IncorrectDatabaseSetup(RuntimeError):
|
class IncorrectDatabaseSetup(RuntimeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
ConnectionType = TypeVar("ConnectionType", bound=Connection)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
|
||||||
|
def __init__(self, module, database_config: dict):
|
||||||
|
self.module = module
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
|
def single_threaded(self) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
|
def can_native_upsert(self) -> bool:
|
||||||
|
"""
|
||||||
|
Do we support native UPSERTs?
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
|
def supports_tuple_comparison(self) -> bool:
|
||||||
|
"""
|
||||||
|
Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
|
def supports_using_any_list(self) -> bool:
|
||||||
|
"""
|
||||||
|
Do we support using `a = ANY(?)` and passing a list
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def check_database(
|
||||||
|
self, db_conn: ConnectionType, allow_outdated_version: bool = False
|
||||||
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def check_new_database(self, txn) -> None:
|
||||||
|
"""Gets called when setting up a brand new database. This allows us to
|
||||||
|
apply stricter checks on new databases versus existing database.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def convert_param_style(self, sql: str) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def on_new_connection(self, db_conn: ConnectionType) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def is_deadlock(self, error: Exception) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def is_connection_closed(self, conn: ConnectionType) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def lock_table(self, txn, table: str) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_next_state_group_id(self, txn) -> int:
|
||||||
|
"""Returns an int that can be used as a new state_group ID
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
|
def server_version(self) -> str:
|
||||||
|
"""Gets a string giving the server version. For example: '3.22.0'
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
|
@ -15,16 +15,14 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from ._base import IncorrectDatabaseSetup
|
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PostgresEngine(object):
|
class PostgresEngine(BaseDatabaseEngine):
|
||||||
single_threaded = False
|
|
||||||
|
|
||||||
def __init__(self, database_module, database_config):
|
def __init__(self, database_module, database_config):
|
||||||
self.module = database_module
|
super().__init__(database_module, database_config)
|
||||||
self.module.extensions.register_type(self.module.extensions.UNICODE)
|
self.module.extensions.register_type(self.module.extensions.UNICODE)
|
||||||
|
|
||||||
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
|
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
|
||||||
|
@ -36,6 +34,10 @@ class PostgresEngine(object):
|
||||||
self.synchronous_commit = database_config.get("synchronous_commit", True)
|
self.synchronous_commit = database_config.get("synchronous_commit", True)
|
||||||
self._version = None # unknown as yet
|
self._version = None # unknown as yet
|
||||||
|
|
||||||
|
@property
|
||||||
|
def single_threaded(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
def check_database(self, db_conn, allow_outdated_version: bool = False):
|
def check_database(self, db_conn, allow_outdated_version: bool = False):
|
||||||
# Get the version of PostgreSQL that we're using. As per the psycopg2
|
# Get the version of PostgreSQL that we're using. As per the psycopg2
|
||||||
# docs: The number is formed by converting the major, minor, and
|
# docs: The number is formed by converting the major, minor, and
|
||||||
|
|
|
@ -12,16 +12,16 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import sqlite3
|
||||||
import struct
|
import struct
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
from synapse.storage.engines import BaseDatabaseEngine
|
||||||
|
|
||||||
class Sqlite3Engine(object):
|
|
||||||
single_threaded = True
|
|
||||||
|
|
||||||
|
class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
|
||||||
def __init__(self, database_module, database_config):
|
def __init__(self, database_module, database_config):
|
||||||
self.module = database_module
|
super().__init__(database_module, database_config)
|
||||||
|
|
||||||
database = database_config.get("args", {}).get("database")
|
database = database_config.get("args", {}).get("database")
|
||||||
self._is_in_memory = database in (None, ":memory:",)
|
self._is_in_memory = database in (None, ":memory:",)
|
||||||
|
@ -31,6 +31,10 @@ class Sqlite3Engine(object):
|
||||||
self._current_state_group_id = None
|
self._current_state_group_id = None
|
||||||
self._current_state_group_id_lock = threading.Lock()
|
self._current_state_group_id_lock = threading.Lock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def single_threaded(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_native_upsert(self):
|
def can_native_upsert(self):
|
||||||
"""
|
"""
|
||||||
|
@ -68,7 +72,6 @@ class Sqlite3Engine(object):
|
||||||
return sql
|
return sql
|
||||||
|
|
||||||
def on_new_connection(self, db_conn):
|
def on_new_connection(self, db_conn):
|
||||||
|
|
||||||
# We need to import here to avoid an import loop.
|
# We need to import here to avoid an import loop.
|
||||||
from synapse.storage.prepare_database import prepare_database
|
from synapse.storage.prepare_database import prepare_database
|
||||||
|
|
||||||
|
|
65
synapse/storage/types.py
Normal file
65
synapse/storage/types.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, Iterable, Iterator, List, Tuple
|
||||||
|
|
||||||
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Cursor(Protocol):
|
||||||
|
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
def fetchall(self) -> List[Tuple]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def fetchone(self) -> Tuple:
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> Any:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rowcount(self) -> int:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[Tuple]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class Connection(Protocol):
|
||||||
|
def cursor(self) -> Cursor:
|
||||||
|
...
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def commit(self) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def rollback(self, *args, **kwargs) -> None:
|
||||||
|
...
|
5
tox.ini
5
tox.ini
|
@ -168,7 +168,6 @@ commands=
|
||||||
coverage html
|
coverage html
|
||||||
|
|
||||||
[testenv:mypy]
|
[testenv:mypy]
|
||||||
basepython = python3.7
|
|
||||||
skip_install = True
|
skip_install = True
|
||||||
deps =
|
deps =
|
||||||
{[base]deps}
|
{[base]deps}
|
||||||
|
@ -179,7 +178,8 @@ env =
|
||||||
extras = all
|
extras = all
|
||||||
commands = mypy \
|
commands = mypy \
|
||||||
synapse/api \
|
synapse/api \
|
||||||
synapse/config/ \
|
synapse/appservice \
|
||||||
|
synapse/config \
|
||||||
synapse/events/spamcheck.py \
|
synapse/events/spamcheck.py \
|
||||||
synapse/federation/sender \
|
synapse/federation/sender \
|
||||||
synapse/federation/transport \
|
synapse/federation/transport \
|
||||||
|
@ -192,6 +192,7 @@ commands = mypy \
|
||||||
synapse/rest \
|
synapse/rest \
|
||||||
synapse/spam_checker_api \
|
synapse/spam_checker_api \
|
||||||
synapse/storage/engines \
|
synapse/storage/engines \
|
||||||
|
synapse/storage/database.py \
|
||||||
synapse/streams
|
synapse/streams
|
||||||
|
|
||||||
# To find all folders that pass mypy you run:
|
# To find all folders that pass mypy you run:
|
||||||
|
|
Loading…
Reference in a new issue