0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 13:53:52 +01:00

Type hint the constructors of the data store classes (#11555)

This commit is contained in:
Sean Quah 2021-12-13 17:05:00 +00:00 committed by GitHub
parent 1abfb15f07
commit 5305a5e881
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
35 changed files with 351 additions and 87 deletions

1
changelog.d/11555.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to storage classes.

View file

@ -15,7 +15,7 @@
import logging import logging
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator
@ -27,7 +27,12 @@ logger = logging.getLogger(__name__)
class BaseSlavedStore(CacheInvalidationWorkerStore): class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen: Optional[ self._cache_id_gen: Optional[

View file

@ -14,7 +14,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
@ -25,7 +25,12 @@ if TYPE_CHECKING:
class SlavedClientIpStore(BaseSlavedStore): class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.client_ip_last_seen: LruCache[tuple, int] = LruCache( self.client_ip_last_seen: LruCache[tuple, int] = LruCache(

View file

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.devices import DeviceWorkerStore from synapse.storage.databases.main.devices import DeviceWorkerStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -27,7 +27,12 @@ if TYPE_CHECKING:
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.hs = hs self.hs = hs

View file

@ -15,7 +15,7 @@
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
from synapse.storage.databases.main.event_push_actions import ( from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore, EventPushActionsWorkerStore,
@ -58,7 +58,12 @@ class SlavedEventStore(
RelationsWorkerStore, RelationsWorkerStore,
BaseSlavedStore, BaseSlavedStore,
): ):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token() events_max = self._stream_id_gen.get_current_token()

View file

@ -14,7 +14,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.filtering import FilteringStore from synapse.storage.databases.main.filtering import FilteringStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
@ -24,7 +24,12 @@ if TYPE_CHECKING:
class SlavedFilteringStore(BaseSlavedStore): class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired # Filters are immutable so this cache doesn't need to be expired

View file

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import GroupServerStream from synapse.replication.tcp.streams import GroupServerStream
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.group_server import GroupServerWorkerStore from synapse.storage.databases.main.group_server import GroupServerWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -26,7 +26,12 @@ if TYPE_CHECKING:
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.hs = hs self.hs = hs

View file

@ -17,10 +17,8 @@ import logging
from abc import ABCMeta from abc import ABCMeta
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util import json_decoder from synapse.util import json_decoder
@ -38,7 +36,12 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database). per data store (and not one per physical database).
""" """
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
self.hs = hs self.hs = hs
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.database_engine = database.engine self.database_engine = database.engine

View file

@ -175,7 +175,7 @@ class LoggingDatabaseConnection:
def rollback(self) -> None: def rollback(self) -> None:
self.conn.rollback() self.conn.rollback()
def __enter__(self) -> "Connection": def __enter__(self) -> "LoggingDatabaseConnection":
self.conn.__enter__() self.conn.__enter__()
return self return self

View file

@ -18,7 +18,7 @@ import logging
from typing import TYPE_CHECKING, List, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import ( from synapse.storage.util.id_generators import (
@ -129,7 +129,12 @@ class DataStore(
LockStore, LockStore,
SessionStore, SessionStore,
): ):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
self.hs = hs self.hs = hs
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.database_engine = database.engine self.database_engine = database.engine

View file

@ -24,9 +24,8 @@ from synapse.appservice import (
from synapse.config.appservice import load_appservices from synapse.config.appservice import load_appservices
from synapse.events import EventBase from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.types import Connection
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
@ -58,7 +57,12 @@ def _make_exclusive_regex(
class ApplicationServiceWorkerStore(SQLBaseStore): class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
self.services_cache = load_appservices( self.services_cache = load_appservices(
hs.hostname, hs.config.appservice.app_service_config_files hs.hostname, hs.config.appservice.app_service_config_files
) )

View file

@ -25,7 +25,7 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow, EventsStreamEventRow,
) )
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
@ -41,7 +41,12 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore): class CacheInvalidationWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()

View file

@ -18,7 +18,11 @@ from typing import TYPE_CHECKING, Optional
from synapse.events.utils import prune_event_dict from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util import json_encoder from synapse.util import json_encoder
@ -31,7 +35,12 @@ logger = logging.getLogger(__name__)
class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore): class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if ( if (

View file

@ -26,7 +26,6 @@ from synapse.storage.database import (
make_tuple_comparison_clause, make_tuple_comparison_clause,
) )
from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
from synapse.storage.types import Connection
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
@ -65,7 +64,12 @@ class LastConnectionInfo(TypedDict):
class ClientIpBackgroundUpdateStore(SQLBaseStore): class ClientIpBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
@ -394,7 +398,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.user_ips_max_age = hs.config.server.user_ips_max_age self.user_ips_max_age = hs.config.server.user_ips_max_age
@ -532,7 +541,12 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
# (user_id, access_token, ip,) -> last_seen # (user_id, access_token, ip,) -> last_seen
self.client_ip_last_seen = LruCache[Tuple[str, str, str], int]( self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](

View file

@ -601,7 +601,12 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox" REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox" REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(

View file

@ -38,6 +38,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction, LoggingTransaction,
make_tuple_comparison_clause, make_tuple_comparison_clause,
) )
@ -61,7 +62,12 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore): class DeviceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks: if hs.config.worker.run_background_tasks:
@ -953,7 +959,12 @@ class DeviceWorkerStore(SQLBaseStore):
class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
@ -1085,7 +1096,12 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies # Map of (user_id, device_id) -> bool. If there is an entry that implies

View file

@ -24,7 +24,11 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.events import EventBase, make_event_from_dict from synapse.events import EventBase, make_event_from_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
@ -62,7 +66,12 @@ class _NoChainCoverIndex(Exception):
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks: if hs.config.worker.run_background_tasks:
@ -1514,7 +1523,12 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only" EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(

View file

@ -20,7 +20,11 @@ from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -82,7 +86,12 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore): class EventPushActionsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn # These get correctly set by _find_stream_orderings_for_times_txn
@ -910,7 +919,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
class EventPushActionsStore(EventPushActionsWorkerStore): class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index" EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(

View file

@ -41,10 +41,13 @@ from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id from synapse.types import StateMap, get_domain_from_id
@ -95,7 +98,7 @@ class PersistEventsStore:
hs: "HomeServer", hs: "HomeServer",
db: DatabasePool, db: DatabasePool,
main_data_store: "DataStore", main_data_store: "DataStore",
db_conn: Connection, db_conn: LoggingDatabaseConnection,
): ):
self.hs = hs self.hs = hs
self.db_pool = db self.db_pool = db

View file

@ -23,6 +23,7 @@ from synapse.events import make_event_from_dict
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction, LoggingTransaction,
make_tuple_comparison_clause, make_tuple_comparison_clause,
) )
@ -83,7 +84,12 @@ class _CalculateChainCover:
class EventsBackgroundUpdatesStore(SQLBaseStore): class EventsBackgroundUpdatesStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(

View file

@ -19,8 +19,7 @@ from typing_extensions import TypedDict
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.types import Connection
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
@ -40,7 +39,12 @@ class _RoomInGroup(TypedDict):
class GroupServerWorkerStore(SQLBaseStore): class GroupServerWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
database.updates.register_background_index_update( database.updates.register_background_index_update(
update_name="local_group_updates_index", update_name="local_group_updates_index",
index_name="local_group_updates_stream_id_index", index_name="local_group_updates_stream_id_index",

View file

@ -20,8 +20,11 @@ from twisted.internet.interfaces import IReactorCore
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import (
from synapse.storage.types import Connection DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.util import Clock from synapse.util import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -54,7 +57,12 @@ class LockStore(SQLBaseStore):
`last_renewed_ts` column with the current time. `last_renewed_ts` column with the current time.
""" """
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._reactor = hs.get_reactor() self._reactor = hs.get_reactor()

View file

@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Dict
from synapse.metrics import GaugeBucketCollector from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.event_push_actions import ( from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore, EventPushActionsWorkerStore,
) )
@ -55,7 +55,12 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
stats and prometheus metrics. stats and prometheus metrics.
""" """
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# Read the extrems every 60 minutes # Read the extrems every 60 minutes

View file

@ -16,7 +16,11 @@ from typing import TYPE_CHECKING, Dict, List, Optional
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
make_in_list_sql_clause,
)
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.threepids import canonicalise_email from synapse.util.threepids import canonicalise_email
@ -31,7 +35,12 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersWorkerStore(SQLBaseStore): class MonthlyActiveUsersWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.hs = hs self.hs = hs
@ -213,7 +222,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._mau_stats_only = hs.config.server.mau_stats_only self._mau_stats_only = hs.config.server.mau_stats_only

View file

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
from synapse.api.presence import PresenceState, UserPresenceState from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream from synapse.replication.tcp.streams import PresenceStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
@ -33,7 +33,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
def __init__( def __init__(
self, self,
database: DatabasePool, database: DatabasePool,
db_conn: Connection, db_conn: LoggingDatabaseConnection,
hs: "HomeServer", hs: "HomeServer",
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
@ -52,7 +52,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
def __init__( def __init__(
self, self,
database: DatabasePool, database: DatabasePool,
db_conn: Connection, db_conn: LoggingDatabaseConnection,
hs: "HomeServer", hs: "HomeServer",
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)

View file

@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.pusher import PusherWorkerStore from synapse.storage.databases.main.pusher import PusherWorkerStore
@ -81,7 +81,12 @@ class PushRulesWorkerStore(
`get_max_push_rules_stream_id` which can be called in the initializer. `get_max_push_rules_stream_id` which can be called in the initializer.
""" """
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None: if hs.config.worker.worker_app is None:

View file

@ -32,7 +32,11 @@ from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict from synapse.types import JsonDict
@ -47,7 +51,12 @@ logger = logging.getLogger(__name__)
class ReceiptsWorkerStore(SQLBaseStore): class ReceiptsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine): if isinstance(database.engine, PostgresEngine):

View file

@ -24,7 +24,11 @@ from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.search import SearchStore from synapse.storage.databases.main.search import SearchStore
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.types import JsonDict, ThirdPartyInstanceID
@ -72,7 +76,12 @@ class RoomSortOrder(Enum):
class RoomWorkerStore(SQLBaseStore): class RoomWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.config = hs.config self.config = hs.config
@ -1050,7 +1059,12 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
class RoomBackgroundUpdateStore(SQLBaseStore): class RoomBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.config = hs.config self.config = hs.config
@ -1435,7 +1449,12 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.config = hs.config self.config = hs.config

View file

@ -37,7 +37,7 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process, wrap_as_background_process,
) )
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine from synapse.storage.engines import Sqlite3Engine
from synapse.storage.roommember import ( from synapse.storage.roommember import (
@ -64,7 +64,12 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# Used by `_get_joined_hosts` to ensure only one thing mutates the cache # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@ -985,7 +990,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler( self.db_pool.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@ -1135,7 +1145,12 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
async def forget(self, user_id: str, room_id: str) -> None: async def forget(self, user_id: str, room_id: str) -> None:

View file

@ -20,7 +20,11 @@ from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@ -105,7 +109,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if not hs.config.server.enable_search: if not hs.config.server.enable_search:
@ -358,7 +367,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
class SearchStore(SearchBackgroundUpdateStore): class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
async def search_msgs(self, room_ids, search_term, keys): async def search_msgs(self, room_ids, search_term, keys):

View file

@ -22,7 +22,11 @@ from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
@ -56,7 +60,12 @@ class _GetStateGroupDelta(
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers.""" """The parts of StateGroupStore that can be called from workers."""
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
async def get_room_version(self, room_id: str) -> RoomVersion: async def get_room_version(self, room_id: str) -> RoomVersion:
@ -349,7 +358,12 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name = hs.hostname
@ -536,5 +550,10 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
* `state_groups_state`: Maps state group to state events. * `state_groups_state`: Maps state group to state events.
""" """
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)

View file

@ -24,7 +24,7 @@ from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -96,7 +96,12 @@ class UserSortOrder(Enum):
class StatsStore(StateDeltasStore): class StatsStore(StateDeltasStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.server_name = hs.hostname self.server_name = hs.hostname

View file

@ -49,6 +49,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction, LoggingTransaction,
make_in_list_sql_clause, make_in_list_sql_clause,
) )
@ -339,7 +340,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
which can be called in the initializer. which can be called in the initializer.
""" """
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()

View file

@ -22,7 +22,11 @@ from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import db_to_json from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -71,7 +75,12 @@ class DestinationRetryTimings:
class TransactionWorkerStore(CacheInvalidationWorkerStore): class TransactionWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if hs.config.worker.run_background_tasks: if hs.config.worker.run_background_tasks:

View file

@ -32,11 +32,14 @@ if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection
from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -53,7 +56,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
def __init__( def __init__(
self, self,
database: DatabasePool, database: DatabasePool,
db_conn: Connection, db_conn: LoggingDatabaseConnection,
hs: "HomeServer", hs: "HomeServer",
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
@ -592,7 +595,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def __init__( def __init__(
self, self,
database: DatabasePool, database: DatabasePool,
db_conn: Connection, db_conn: LoggingDatabaseConnection,
hs: "HomeServer", hs: "HomeServer",
) -> None: ) -> None:
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)