forked from MirrorHub/synapse
Rename storage classes (#12913)
This commit is contained in:
parent
e541bb9eed
commit
1e453053cb
53 changed files with 708 additions and 551 deletions
1
changelog.d/12913.misc
Normal file
1
changelog.d/12913.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Rename storage classes.
|
|
@ -22,7 +22,7 @@ from synapse.events import EventBase
|
||||||
from synapse.types import JsonDict, StateMap
|
from synapse.types import JsonDict, StateMap
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.storage import Storage
|
from synapse.storage.controllers import StorageControllers
|
||||||
from synapse.storage.databases.main import DataStore
|
from synapse.storage.databases.main import DataStore
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
|
|
||||||
|
@ -84,7 +84,7 @@ class EventContext:
|
||||||
incomplete state.
|
incomplete state.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_storage: "Storage"
|
_storage: "StorageControllers"
|
||||||
rejected: Union[Literal[False], str] = False
|
rejected: Union[Literal[False], str] = False
|
||||||
_state_group: Optional[int] = None
|
_state_group: Optional[int] = None
|
||||||
state_group_before_event: Optional[int] = None
|
state_group_before_event: Optional[int] = None
|
||||||
|
@ -97,7 +97,7 @@ class EventContext:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def with_state(
|
def with_state(
|
||||||
storage: "Storage",
|
storage: "StorageControllers",
|
||||||
state_group: Optional[int],
|
state_group: Optional[int],
|
||||||
state_group_before_event: Optional[int],
|
state_group_before_event: Optional[int],
|
||||||
state_delta_due_to_event: Optional[StateMap[str]],
|
state_delta_due_to_event: Optional[StateMap[str]],
|
||||||
|
@ -117,7 +117,7 @@ class EventContext:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def for_outlier(
|
def for_outlier(
|
||||||
storage: "Storage",
|
storage: "StorageControllers",
|
||||||
) -> "EventContext":
|
) -> "EventContext":
|
||||||
"""Return an EventContext instance suitable for persisting an outlier event"""
|
"""Return an EventContext instance suitable for persisting an outlier event"""
|
||||||
return EventContext(storage=storage)
|
return EventContext(storage=storage)
|
||||||
|
@ -147,7 +147,7 @@ class EventContext:
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
|
def deserialize(storage: "StorageControllers", input: JsonDict) -> "EventContext":
|
||||||
"""Converts a dict that was produced by `serialize` back into a
|
"""Converts a dict that was produced by `serialize` back into a
|
||||||
EventContext.
|
EventContext.
|
||||||
|
|
||||||
|
|
|
@ -109,7 +109,6 @@ class FederationServer(FederationBase):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
self.handler = hs.get_federation_handler()
|
self.handler = hs.get_federation_handler()
|
||||||
self.storage = hs.get_storage()
|
|
||||||
self._spam_checker = hs.get_spam_checker()
|
self._spam_checker = hs.get_spam_checker()
|
||||||
self._federation_event_handler = hs.get_federation_event_handler()
|
self._federation_event_handler = hs.get_federation_event_handler()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
|
|
@ -30,8 +30,8 @@ logger = logging.getLogger(__name__)
|
||||||
class AdminHandler:
|
class AdminHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.state_storage = self.storage.state
|
self._state_storage_controller = self._storage_controllers.state
|
||||||
|
|
||||||
async def get_whois(self, user: UserID) -> JsonDict:
|
async def get_whois(self, user: UserID) -> JsonDict:
|
||||||
connections = []
|
connections = []
|
||||||
|
@ -197,7 +197,9 @@ class AdminHandler:
|
||||||
|
|
||||||
from_key = events[-1].internal_metadata.after
|
from_key = events[-1].internal_metadata.after
|
||||||
|
|
||||||
events = await filter_events_for_client(self.storage, user_id, events)
|
events = await filter_events_for_client(
|
||||||
|
self._storage_controllers, user_id, events
|
||||||
|
)
|
||||||
|
|
||||||
writer.write_events(room_id, events)
|
writer.write_events(room_id, events)
|
||||||
|
|
||||||
|
@ -233,7 +235,9 @@ class AdminHandler:
|
||||||
for event_id in extremities:
|
for event_id in extremities:
|
||||||
if not event_to_unseen_prevs[event_id]:
|
if not event_to_unseen_prevs[event_id]:
|
||||||
continue
|
continue
|
||||||
state = await self.state_storage.get_state_for_event(event_id)
|
state = await self._state_storage_controller.get_state_for_event(
|
||||||
|
event_id
|
||||||
|
)
|
||||||
writer.write_state(room_id, event_id, state)
|
writer.write_state(room_id, event_id, state)
|
||||||
|
|
||||||
return writer.finished()
|
return writer.finished()
|
||||||
|
|
|
@ -71,7 +71,7 @@ class DeviceWorkerHandler:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.state_storage = hs.get_storage().state
|
self._state_storage = hs.get_storage_controllers().state
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
|
|
||||||
|
@ -204,7 +204,7 @@ class DeviceWorkerHandler:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# mapping from event_id -> state_dict
|
# mapping from event_id -> state_dict
|
||||||
prev_state_ids = await self.state_storage.get_state_ids_for_events(
|
prev_state_ids = await self._state_storage.get_state_ids_for_events(
|
||||||
event_ids
|
event_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -139,7 +139,7 @@ class EventStreamHandler:
|
||||||
class EventHandler:
|
class EventHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
|
|
||||||
async def get_event(
|
async def get_event(
|
||||||
self,
|
self,
|
||||||
|
@ -177,7 +177,7 @@ class EventHandler:
|
||||||
is_peeking = user.to_string() not in users
|
is_peeking = user.to_string() not in users
|
||||||
|
|
||||||
filtered = await filter_events_for_client(
|
filtered = await filter_events_for_client(
|
||||||
self.storage, user.to_string(), [event], is_peeking=is_peeking
|
self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
|
||||||
)
|
)
|
||||||
|
|
||||||
if not filtered:
|
if not filtered:
|
||||||
|
|
|
@ -125,8 +125,8 @@ class FederationHandler:
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.state_storage = self.storage.state
|
self._state_storage_controller = self._storage_controllers.state
|
||||||
self.federation_client = hs.get_federation_client()
|
self.federation_client = hs.get_federation_client()
|
||||||
self.state_handler = hs.get_state_handler()
|
self.state_handler = hs.get_state_handler()
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
|
@ -324,7 +324,7 @@ class FederationHandler:
|
||||||
# We set `check_history_visibility_only` as we might otherwise get false
|
# We set `check_history_visibility_only` as we might otherwise get false
|
||||||
# positives from users having been erased.
|
# positives from users having been erased.
|
||||||
filtered_extremities = await filter_events_for_server(
|
filtered_extremities = await filter_events_for_server(
|
||||||
self.storage,
|
self._storage_controllers,
|
||||||
self.server_name,
|
self.server_name,
|
||||||
events_to_check,
|
events_to_check,
|
||||||
redact=False,
|
redact=False,
|
||||||
|
@ -660,7 +660,7 @@ class FederationHandler:
|
||||||
# in the invitee's sync stream. It is stripped out for all other local users.
|
# in the invitee's sync stream. It is stripped out for all other local users.
|
||||||
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
|
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
|
||||||
|
|
||||||
context = EventContext.for_outlier(self.storage)
|
context = EventContext.for_outlier(self._storage_controllers)
|
||||||
stream_id = await self._federation_event_handler.persist_events_and_notify(
|
stream_id = await self._federation_event_handler.persist_events_and_notify(
|
||||||
event.room_id, [(event, context)]
|
event.room_id, [(event, context)]
|
||||||
)
|
)
|
||||||
|
@ -849,7 +849,7 @@ class FederationHandler:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
context = EventContext.for_outlier(self.storage)
|
context = EventContext.for_outlier(self._storage_controllers)
|
||||||
await self._federation_event_handler.persist_events_and_notify(
|
await self._federation_event_handler.persist_events_and_notify(
|
||||||
event.room_id, [(event, context)]
|
event.room_id, [(event, context)]
|
||||||
)
|
)
|
||||||
|
@ -878,7 +878,7 @@ class FederationHandler:
|
||||||
|
|
||||||
await self.federation_client.send_leave(host_list, event)
|
await self.federation_client.send_leave(host_list, event)
|
||||||
|
|
||||||
context = EventContext.for_outlier(self.storage)
|
context = EventContext.for_outlier(self._storage_controllers)
|
||||||
stream_id = await self._federation_event_handler.persist_events_and_notify(
|
stream_id = await self._federation_event_handler.persist_events_and_notify(
|
||||||
event.room_id, [(event, context)]
|
event.room_id, [(event, context)]
|
||||||
)
|
)
|
||||||
|
@ -1027,7 +1027,7 @@ class FederationHandler:
|
||||||
if event.internal_metadata.outlier:
|
if event.internal_metadata.outlier:
|
||||||
raise NotFoundError("State not known at event %s" % (event_id,))
|
raise NotFoundError("State not known at event %s" % (event_id,))
|
||||||
|
|
||||||
state_groups = await self.state_storage.get_state_groups_ids(
|
state_groups = await self._state_storage_controller.get_state_groups_ids(
|
||||||
room_id, [event_id]
|
room_id, [event_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1078,7 +1078,9 @@ class FederationHandler:
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
events = await filter_events_for_server(self.storage, origin, events)
|
events = await filter_events_for_server(
|
||||||
|
self._storage_controllers, origin, events
|
||||||
|
)
|
||||||
|
|
||||||
return events
|
return events
|
||||||
|
|
||||||
|
@ -1109,7 +1111,9 @@ class FederationHandler:
|
||||||
if not in_room:
|
if not in_room:
|
||||||
raise AuthError(403, "Host not in room.")
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
events = await filter_events_for_server(self.storage, origin, [event])
|
events = await filter_events_for_server(
|
||||||
|
self._storage_controllers, origin, [event]
|
||||||
|
)
|
||||||
event = events[0]
|
event = events[0]
|
||||||
return event
|
return event
|
||||||
else:
|
else:
|
||||||
|
@ -1138,7 +1142,7 @@ class FederationHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
missing_events = await filter_events_for_server(
|
missing_events = await filter_events_for_server(
|
||||||
self.storage, origin, missing_events
|
self._storage_controllers, origin, missing_events
|
||||||
)
|
)
|
||||||
|
|
||||||
return missing_events
|
return missing_events
|
||||||
|
@ -1480,9 +1484,11 @@ class FederationHandler:
|
||||||
# clear the lazy-loading flag.
|
# clear the lazy-loading flag.
|
||||||
logger.info("Updating current state for %s", room_id)
|
logger.info("Updating current state for %s", room_id)
|
||||||
assert (
|
assert (
|
||||||
self.storage.persistence is not None
|
self._storage_controllers.persistence is not None
|
||||||
), "TODO(faster_joins): support for workers"
|
), "TODO(faster_joins): support for workers"
|
||||||
await self.storage.persistence.update_current_state(room_id)
|
await self._storage_controllers.persistence.update_current_state(
|
||||||
|
room_id
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Clearing partial-state flag for %s", room_id)
|
logger.info("Clearing partial-state flag for %s", room_id)
|
||||||
success = await self.store.clear_partial_state_room(room_id)
|
success = await self.store.clear_partial_state_room(room_id)
|
||||||
|
|
|
@ -98,8 +98,8 @@ class FederationEventHandler:
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self._store = hs.get_datastores().main
|
self._store = hs.get_datastores().main
|
||||||
self._storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self._state_storage = self._storage.state
|
self._state_storage_controller = self._storage_controllers.state
|
||||||
|
|
||||||
self._state_handler = hs.get_state_handler()
|
self._state_handler = hs.get_state_handler()
|
||||||
self._event_creation_handler = hs.get_event_creation_handler()
|
self._event_creation_handler = hs.get_event_creation_handler()
|
||||||
|
@ -535,7 +535,9 @@ class FederationEventHandler:
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
await self._store.update_state_for_partial_state_event(event, context)
|
await self._store.update_state_for_partial_state_event(event, context)
|
||||||
self._state_storage.notify_event_un_partial_stated(event.event_id)
|
self._state_storage_controller.notify_event_un_partial_stated(
|
||||||
|
event.event_id
|
||||||
|
)
|
||||||
|
|
||||||
async def backfill(
|
async def backfill(
|
||||||
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
|
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
|
||||||
|
@ -835,7 +837,9 @@ class FederationEventHandler:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get the state of the events we know about
|
# Get the state of the events we know about
|
||||||
ours = await self._state_storage.get_state_groups_ids(room_id, seen)
|
ours = await self._state_storage_controller.get_state_groups_ids(
|
||||||
|
room_id, seen
|
||||||
|
)
|
||||||
|
|
||||||
# state_maps is a list of mappings from (type, state_key) to event_id
|
# state_maps is a list of mappings from (type, state_key) to event_id
|
||||||
state_maps: List[StateMap[str]] = list(ours.values())
|
state_maps: List[StateMap[str]] = list(ours.values())
|
||||||
|
@ -1436,7 +1440,7 @@ class FederationEventHandler:
|
||||||
# we're not bothering about room state, so flag the event as an outlier.
|
# we're not bothering about room state, so flag the event as an outlier.
|
||||||
event.internal_metadata.outlier = True
|
event.internal_metadata.outlier = True
|
||||||
|
|
||||||
context = EventContext.for_outlier(self._storage)
|
context = EventContext.for_outlier(self._storage_controllers)
|
||||||
try:
|
try:
|
||||||
validate_event_for_room_version(room_version_obj, event)
|
validate_event_for_room_version(room_version_obj, event)
|
||||||
check_auth_rules_for_event(room_version_obj, event, auth)
|
check_auth_rules_for_event(room_version_obj, event, auth)
|
||||||
|
@ -1613,7 +1617,7 @@ class FederationEventHandler:
|
||||||
# given state at the event. This should correctly handle cases
|
# given state at the event. This should correctly handle cases
|
||||||
# like bans, especially with state res v2.
|
# like bans, especially with state res v2.
|
||||||
|
|
||||||
state_sets_d = await self._state_storage.get_state_groups_ids(
|
state_sets_d = await self._state_storage_controller.get_state_groups_ids(
|
||||||
event.room_id, extrem_ids
|
event.room_id, extrem_ids
|
||||||
)
|
)
|
||||||
state_sets: List[StateMap[str]] = list(state_sets_d.values())
|
state_sets: List[StateMap[str]] = list(state_sets_d.values())
|
||||||
|
@ -1885,7 +1889,7 @@ class FederationEventHandler:
|
||||||
|
|
||||||
# create a new state group as a delta from the existing one.
|
# create a new state group as a delta from the existing one.
|
||||||
prev_group = context.state_group
|
prev_group = context.state_group
|
||||||
state_group = await self._state_storage.store_state_group(
|
state_group = await self._state_storage_controller.store_state_group(
|
||||||
event.event_id,
|
event.event_id,
|
||||||
event.room_id,
|
event.room_id,
|
||||||
prev_group=prev_group,
|
prev_group=prev_group,
|
||||||
|
@ -1894,7 +1898,7 @@ class FederationEventHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
return EventContext.with_state(
|
return EventContext.with_state(
|
||||||
storage=self._storage,
|
storage=self._storage_controllers,
|
||||||
state_group=state_group,
|
state_group=state_group,
|
||||||
state_group_before_event=context.state_group_before_event,
|
state_group_before_event=context.state_group_before_event,
|
||||||
state_delta_due_to_event=state_updates,
|
state_delta_due_to_event=state_updates,
|
||||||
|
@ -1984,11 +1988,14 @@ class FederationEventHandler:
|
||||||
)
|
)
|
||||||
return result["max_stream_id"]
|
return result["max_stream_id"]
|
||||||
else:
|
else:
|
||||||
assert self._storage.persistence
|
assert self._storage_controllers.persistence
|
||||||
|
|
||||||
# Note that this returns the events that were persisted, which may not be
|
# Note that this returns the events that were persisted, which may not be
|
||||||
# the same as were passed in if some were deduplicated due to transaction IDs.
|
# the same as were passed in if some were deduplicated due to transaction IDs.
|
||||||
events, max_stream_token = await self._storage.persistence.persist_events(
|
(
|
||||||
|
events,
|
||||||
|
max_stream_token,
|
||||||
|
) = await self._storage_controllers.persistence.persist_events(
|
||||||
event_and_contexts, backfilled=backfilled
|
event_and_contexts, backfilled=backfilled
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -67,8 +67,8 @@ class InitialSyncHandler:
|
||||||
]
|
]
|
||||||
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
|
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.state_storage = self.storage.state
|
self._state_storage_controller = self._storage_controllers.state
|
||||||
|
|
||||||
async def snapshot_all_rooms(
|
async def snapshot_all_rooms(
|
||||||
self,
|
self,
|
||||||
|
@ -198,7 +198,8 @@ class InitialSyncHandler:
|
||||||
event.stream_ordering,
|
event.stream_ordering,
|
||||||
)
|
)
|
||||||
deferred_room_state = run_in_background(
|
deferred_room_state = run_in_background(
|
||||||
self.state_storage.get_state_for_events, [event.event_id]
|
self._state_storage_controller.get_state_for_events,
|
||||||
|
[event.event_id],
|
||||||
).addCallback(
|
).addCallback(
|
||||||
lambda states: cast(StateMap[EventBase], states[event.event_id])
|
lambda states: cast(StateMap[EventBase], states[event.event_id])
|
||||||
)
|
)
|
||||||
|
@ -218,7 +219,7 @@ class InitialSyncHandler:
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
messages = await filter_events_for_client(
|
messages = await filter_events_for_client(
|
||||||
self.storage, user_id, messages
|
self._storage_controllers, user_id, messages
|
||||||
)
|
)
|
||||||
|
|
||||||
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
|
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
|
||||||
|
@ -355,7 +356,9 @@ class InitialSyncHandler:
|
||||||
member_event_id: str,
|
member_event_id: str,
|
||||||
is_peeking: bool,
|
is_peeking: bool,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
room_state = await self.state_storage.get_state_for_event(member_event_id)
|
room_state = await self._state_storage_controller.get_state_for_event(
|
||||||
|
member_event_id
|
||||||
|
)
|
||||||
|
|
||||||
limit = pagin_config.limit if pagin_config else None
|
limit = pagin_config.limit if pagin_config else None
|
||||||
if limit is None:
|
if limit is None:
|
||||||
|
@ -369,7 +372,7 @@ class InitialSyncHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = await filter_events_for_client(
|
messages = await filter_events_for_client(
|
||||||
self.storage, user_id, messages, is_peeking=is_peeking
|
self._storage_controllers, user_id, messages, is_peeking=is_peeking
|
||||||
)
|
)
|
||||||
|
|
||||||
start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
|
start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
|
||||||
|
@ -474,7 +477,7 @@ class InitialSyncHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = await filter_events_for_client(
|
messages = await filter_events_for_client(
|
||||||
self.storage, user_id, messages, is_peeking=is_peeking
|
self._storage_controllers, user_id, messages, is_peeking=is_peeking
|
||||||
)
|
)
|
||||||
|
|
||||||
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
|
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
|
||||||
|
|
|
@ -84,8 +84,8 @@ class MessageHandler:
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.state_storage = self.storage.state
|
self._state_storage_controller = self._storage_controllers.state
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages
|
self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages
|
||||||
|
|
||||||
|
@ -132,7 +132,7 @@ class MessageHandler:
|
||||||
assert (
|
assert (
|
||||||
membership_event_id is not None
|
membership_event_id is not None
|
||||||
), "check_user_in_room_or_world_readable returned invalid data"
|
), "check_user_in_room_or_world_readable returned invalid data"
|
||||||
room_state = await self.state_storage.get_state_for_events(
|
room_state = await self._state_storage_controller.get_state_for_events(
|
||||||
[membership_event_id], StateFilter.from_types([key])
|
[membership_event_id], StateFilter.from_types([key])
|
||||||
)
|
)
|
||||||
data = room_state[membership_event_id].get(key)
|
data = room_state[membership_event_id].get(key)
|
||||||
|
@ -193,7 +193,7 @@ class MessageHandler:
|
||||||
|
|
||||||
# check whether the user is in the room at that time to determine
|
# check whether the user is in the room at that time to determine
|
||||||
# whether they should be treated as peeking.
|
# whether they should be treated as peeking.
|
||||||
state_map = await self.state_storage.get_state_for_event(
|
state_map = await self._state_storage_controller.get_state_for_event(
|
||||||
last_event.event_id,
|
last_event.event_id,
|
||||||
StateFilter.from_types([(EventTypes.Member, user_id)]),
|
StateFilter.from_types([(EventTypes.Member, user_id)]),
|
||||||
)
|
)
|
||||||
|
@ -206,7 +206,7 @@ class MessageHandler:
|
||||||
is_peeking = not joined
|
is_peeking = not joined
|
||||||
|
|
||||||
visible_events = await filter_events_for_client(
|
visible_events = await filter_events_for_client(
|
||||||
self.storage,
|
self._storage_controllers,
|
||||||
user_id,
|
user_id,
|
||||||
[last_event],
|
[last_event],
|
||||||
filter_send_to_client=False,
|
filter_send_to_client=False,
|
||||||
|
@ -214,9 +214,11 @@ class MessageHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
if visible_events:
|
if visible_events:
|
||||||
room_state_events = await self.state_storage.get_state_for_events(
|
room_state_events = (
|
||||||
|
await self._state_storage_controller.get_state_for_events(
|
||||||
[last_event.event_id], state_filter=state_filter
|
[last_event.event_id], state_filter=state_filter
|
||||||
)
|
)
|
||||||
|
)
|
||||||
room_state: Mapping[Any, EventBase] = room_state_events[
|
room_state: Mapping[Any, EventBase] = room_state_events[
|
||||||
last_event.event_id
|
last_event.event_id
|
||||||
]
|
]
|
||||||
|
@ -244,9 +246,11 @@ class MessageHandler:
|
||||||
assert (
|
assert (
|
||||||
membership_event_id is not None
|
membership_event_id is not None
|
||||||
), "check_user_in_room_or_world_readable returned invalid data"
|
), "check_user_in_room_or_world_readable returned invalid data"
|
||||||
room_state_events = await self.state_storage.get_state_for_events(
|
room_state_events = (
|
||||||
|
await self._state_storage_controller.get_state_for_events(
|
||||||
[membership_event_id], state_filter=state_filter
|
[membership_event_id], state_filter=state_filter
|
||||||
)
|
)
|
||||||
|
)
|
||||||
room_state = room_state_events[membership_event_id]
|
room_state = room_state_events[membership_event_id]
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
|
@ -402,7 +406,7 @@ class EventCreationHandler:
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self._event_auth_handler = hs.get_event_auth_handler()
|
self._event_auth_handler = hs.get_event_auth_handler()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.validator = EventValidator()
|
self.validator = EventValidator()
|
||||||
|
@ -1032,7 +1036,7 @@ class EventCreationHandler:
|
||||||
# after it is created
|
# after it is created
|
||||||
if builder.internal_metadata.outlier:
|
if builder.internal_metadata.outlier:
|
||||||
event.internal_metadata.outlier = True
|
event.internal_metadata.outlier = True
|
||||||
context = EventContext.for_outlier(self.storage)
|
context = EventContext.for_outlier(self._storage_controllers)
|
||||||
elif (
|
elif (
|
||||||
event.type == EventTypes.MSC2716_INSERTION
|
event.type == EventTypes.MSC2716_INSERTION
|
||||||
and state_event_ids
|
and state_event_ids
|
||||||
|
@ -1445,7 +1449,7 @@ class EventCreationHandler:
|
||||||
"""
|
"""
|
||||||
extra_users = extra_users or []
|
extra_users = extra_users or []
|
||||||
|
|
||||||
assert self.storage.persistence is not None
|
assert self._storage_controllers.persistence is not None
|
||||||
assert self._events_shard_config.should_handle(
|
assert self._events_shard_config.should_handle(
|
||||||
self._instance_name, event.room_id
|
self._instance_name, event.room_id
|
||||||
)
|
)
|
||||||
|
@ -1679,7 +1683,7 @@ class EventCreationHandler:
|
||||||
event,
|
event,
|
||||||
event_pos,
|
event_pos,
|
||||||
max_stream_token,
|
max_stream_token,
|
||||||
) = await self.storage.persistence.persist_event(
|
) = await self._storage_controllers.persistence.persist_event(
|
||||||
event, context=context, backfilled=backfilled
|
event, context=context, backfilled=backfilled
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -129,8 +129,8 @@ class PaginationHandler:
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.state_storage = self.storage.state
|
self._state_storage_controller = self._storage_controllers.state
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self._server_name = hs.hostname
|
self._server_name = hs.hostname
|
||||||
self._room_shutdown_handler = hs.get_room_shutdown_handler()
|
self._room_shutdown_handler = hs.get_room_shutdown_handler()
|
||||||
|
@ -352,7 +352,7 @@ class PaginationHandler:
|
||||||
self._purges_in_progress_by_room.add(room_id)
|
self._purges_in_progress_by_room.add(room_id)
|
||||||
try:
|
try:
|
||||||
async with self.pagination_lock.write(room_id):
|
async with self.pagination_lock.write(room_id):
|
||||||
await self.storage.purge_events.purge_history(
|
await self._storage_controllers.purge_events.purge_history(
|
||||||
room_id, token, delete_local_events
|
room_id, token, delete_local_events
|
||||||
)
|
)
|
||||||
logger.info("[purge] complete")
|
logger.info("[purge] complete")
|
||||||
|
@ -414,7 +414,7 @@ class PaginationHandler:
|
||||||
if joined:
|
if joined:
|
||||||
raise SynapseError(400, "Users are still joined to this room")
|
raise SynapseError(400, "Users are still joined to this room")
|
||||||
|
|
||||||
await self.storage.purge_events.purge_room(room_id)
|
await self._storage_controllers.purge_events.purge_room(room_id)
|
||||||
|
|
||||||
async def get_messages(
|
async def get_messages(
|
||||||
self,
|
self,
|
||||||
|
@ -529,7 +529,10 @@ class PaginationHandler:
|
||||||
events = await event_filter.filter(events)
|
events = await event_filter.filter(events)
|
||||||
|
|
||||||
events = await filter_events_for_client(
|
events = await filter_events_for_client(
|
||||||
self.storage, user_id, events, is_peeking=(member_event_id is None)
|
self._storage_controllers,
|
||||||
|
user_id,
|
||||||
|
events,
|
||||||
|
is_peeking=(member_event_id is None),
|
||||||
)
|
)
|
||||||
|
|
||||||
# if after the filter applied there are no more events
|
# if after the filter applied there are no more events
|
||||||
|
@ -550,7 +553,7 @@ class PaginationHandler:
|
||||||
(EventTypes.Member, event.sender) for event in events
|
(EventTypes.Member, event.sender) for event in events
|
||||||
)
|
)
|
||||||
|
|
||||||
state_ids = await self.state_storage.get_state_ids_for_event(
|
state_ids = await self._state_storage_controller.get_state_ids_for_event(
|
||||||
events[0].event_id, state_filter=state_filter
|
events[0].event_id, state_filter=state_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -664,7 +667,7 @@ class PaginationHandler:
|
||||||
400, "Users are still joined to this room"
|
400, "Users are still joined to this room"
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.storage.purge_events.purge_room(room_id)
|
await self._storage_controllers.purge_events.purge_room(room_id)
|
||||||
|
|
||||||
logger.info("complete")
|
logger.info("complete")
|
||||||
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE
|
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE
|
||||||
|
|
|
@ -69,7 +69,7 @@ class BundledAggregations:
|
||||||
class RelationsHandler:
|
class RelationsHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self._main_store = hs.get_datastores().main
|
self._main_store = hs.get_datastores().main
|
||||||
self._storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self._auth = hs.get_auth()
|
self._auth = hs.get_auth()
|
||||||
self._clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self._event_handler = hs.get_event_handler()
|
self._event_handler = hs.get_event_handler()
|
||||||
|
@ -143,7 +143,10 @@ class RelationsHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
events = await filter_events_for_client(
|
events = await filter_events_for_client(
|
||||||
self._storage, user_id, events, is_peeking=(member_event_id is None)
|
self._storage_controllers,
|
||||||
|
user_id,
|
||||||
|
events,
|
||||||
|
is_peeking=(member_event_id is None),
|
||||||
)
|
)
|
||||||
|
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
|
|
|
@ -1192,8 +1192,8 @@ class RoomContextHandler:
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.state_storage = self.storage.state
|
self._state_storage_controller = self._storage_controllers.state
|
||||||
self._relations_handler = hs.get_relations_handler()
|
self._relations_handler = hs.get_relations_handler()
|
||||||
|
|
||||||
async def get_event_context(
|
async def get_event_context(
|
||||||
|
@ -1236,7 +1236,10 @@ class RoomContextHandler:
|
||||||
if use_admin_priviledge:
|
if use_admin_priviledge:
|
||||||
return events
|
return events
|
||||||
return await filter_events_for_client(
|
return await filter_events_for_client(
|
||||||
self.storage, user.to_string(), events, is_peeking=is_peeking
|
self._storage_controllers,
|
||||||
|
user.to_string(),
|
||||||
|
events,
|
||||||
|
is_peeking=is_peeking,
|
||||||
)
|
)
|
||||||
|
|
||||||
event = await self.store.get_event(
|
event = await self.store.get_event(
|
||||||
|
@ -1293,7 +1296,7 @@ class RoomContextHandler:
|
||||||
# first? Shouldn't we be consistent with /sync?
|
# first? Shouldn't we be consistent with /sync?
|
||||||
# https://github.com/matrix-org/matrix-doc/issues/687
|
# https://github.com/matrix-org/matrix-doc/issues/687
|
||||||
|
|
||||||
state = await self.state_storage.get_state_for_events(
|
state = await self._state_storage_controller.get_state_for_events(
|
||||||
[last_event_id], state_filter=state_filter
|
[last_event_id], state_filter=state_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ class RoomBatchHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.state_storage = hs.get_storage().state
|
self._state_storage_controller = hs.get_storage_controllers().state
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
self.room_member_handler = hs.get_room_member_handler()
|
self.room_member_handler = hs.get_room_member_handler()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
@ -141,7 +141,7 @@ class RoomBatchHandler:
|
||||||
) = await self.store.get_max_depth_of(event_ids)
|
) = await self.store.get_max_depth_of(event_ids)
|
||||||
# mapping from (type, state_key) -> state_event_id
|
# mapping from (type, state_key) -> state_event_id
|
||||||
assert most_recent_event_id is not None
|
assert most_recent_event_id is not None
|
||||||
prev_state_map = await self.state_storage.get_state_ids_for_event(
|
prev_state_map = await self._state_storage_controller.get_state_ids_for_event(
|
||||||
most_recent_event_id
|
most_recent_event_id
|
||||||
)
|
)
|
||||||
# List of state event ID's
|
# List of state event ID's
|
||||||
|
|
|
@ -55,8 +55,8 @@ class SearchHandler:
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
self._relations_handler = hs.get_relations_handler()
|
self._relations_handler = hs.get_relations_handler()
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.state_storage = self.storage.state
|
self._state_storage_controller = self._storage_controllers.state
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
|
async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
|
||||||
|
@ -460,7 +460,7 @@ class SearchHandler:
|
||||||
filtered_events = await search_filter.filter([r["event"] for r in results])
|
filtered_events = await search_filter.filter([r["event"] for r in results])
|
||||||
|
|
||||||
events = await filter_events_for_client(
|
events = await filter_events_for_client(
|
||||||
self.storage, user.to_string(), filtered_events
|
self._storage_controllers, user.to_string(), filtered_events
|
||||||
)
|
)
|
||||||
|
|
||||||
events.sort(key=lambda e: -rank_map[e.event_id])
|
events.sort(key=lambda e: -rank_map[e.event_id])
|
||||||
|
@ -559,7 +559,7 @@ class SearchHandler:
|
||||||
filtered_events = await search_filter.filter([r["event"] for r in results])
|
filtered_events = await search_filter.filter([r["event"] for r in results])
|
||||||
|
|
||||||
events = await filter_events_for_client(
|
events = await filter_events_for_client(
|
||||||
self.storage, user.to_string(), filtered_events
|
self._storage_controllers, user.to_string(), filtered_events
|
||||||
)
|
)
|
||||||
|
|
||||||
room_events.extend(events)
|
room_events.extend(events)
|
||||||
|
@ -644,11 +644,11 @@ class SearchHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
events_before = await filter_events_for_client(
|
events_before = await filter_events_for_client(
|
||||||
self.storage, user.to_string(), res.events_before
|
self._storage_controllers, user.to_string(), res.events_before
|
||||||
)
|
)
|
||||||
|
|
||||||
events_after = await filter_events_for_client(
|
events_after = await filter_events_for_client(
|
||||||
self.storage, user.to_string(), res.events_after
|
self._storage_controllers, user.to_string(), res.events_after
|
||||||
)
|
)
|
||||||
|
|
||||||
context: JsonDict = {
|
context: JsonDict = {
|
||||||
|
@ -677,7 +677,7 @@ class SearchHandler:
|
||||||
[(EventTypes.Member, sender) for sender in senders]
|
[(EventTypes.Member, sender) for sender in senders]
|
||||||
)
|
)
|
||||||
|
|
||||||
state = await self.state_storage.get_state_for_event(
|
state = await self._state_storage_controller.get_state_for_event(
|
||||||
last_event_id, state_filter
|
last_event_id, state_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -238,8 +238,8 @@ class SyncHandler:
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.state_storage = self.storage.state
|
self._state_storage_controller = self._storage_controllers.state
|
||||||
|
|
||||||
# TODO: flush cache entries on subsequent sync request.
|
# TODO: flush cache entries on subsequent sync request.
|
||||||
# Once we get the next /sync request (ie, one with the same access token
|
# Once we get the next /sync request (ie, one with the same access token
|
||||||
|
@ -512,7 +512,7 @@ class SyncHandler:
|
||||||
current_state_ids = frozenset(current_state_ids_map.values())
|
current_state_ids = frozenset(current_state_ids_map.values())
|
||||||
|
|
||||||
recents = await filter_events_for_client(
|
recents = await filter_events_for_client(
|
||||||
self.storage,
|
self._storage_controllers,
|
||||||
sync_config.user.to_string(),
|
sync_config.user.to_string(),
|
||||||
recents,
|
recents,
|
||||||
always_include_ids=current_state_ids,
|
always_include_ids=current_state_ids,
|
||||||
|
@ -580,7 +580,7 @@ class SyncHandler:
|
||||||
current_state_ids = frozenset(current_state_ids_map.values())
|
current_state_ids = frozenset(current_state_ids_map.values())
|
||||||
|
|
||||||
loaded_recents = await filter_events_for_client(
|
loaded_recents = await filter_events_for_client(
|
||||||
self.storage,
|
self._storage_controllers,
|
||||||
sync_config.user.to_string(),
|
sync_config.user.to_string(),
|
||||||
loaded_recents,
|
loaded_recents,
|
||||||
always_include_ids=current_state_ids,
|
always_include_ids=current_state_ids,
|
||||||
|
@ -630,7 +630,7 @@ class SyncHandler:
|
||||||
event: event of interest
|
event: event of interest
|
||||||
state_filter: The state filter used to fetch state from the database.
|
state_filter: The state filter used to fetch state from the database.
|
||||||
"""
|
"""
|
||||||
state_ids = await self.state_storage.get_state_ids_for_event(
|
state_ids = await self._state_storage_controller.get_state_ids_for_event(
|
||||||
event.event_id, state_filter=state_filter or StateFilter.all()
|
event.event_id, state_filter=state_filter or StateFilter.all()
|
||||||
)
|
)
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
|
@ -710,7 +710,7 @@ class SyncHandler:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
last_event = last_events[-1]
|
last_event = last_events[-1]
|
||||||
state_ids = await self.state_storage.get_state_ids_for_event(
|
state_ids = await self._state_storage_controller.get_state_ids_for_event(
|
||||||
last_event.event_id,
|
last_event.event_id,
|
||||||
state_filter=StateFilter.from_types(
|
state_filter=StateFilter.from_types(
|
||||||
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
|
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
|
||||||
|
@ -889,14 +889,16 @@ class SyncHandler:
|
||||||
if full_state:
|
if full_state:
|
||||||
if batch:
|
if batch:
|
||||||
current_state_ids = (
|
current_state_ids = (
|
||||||
await self.state_storage.get_state_ids_for_event(
|
await self._state_storage_controller.get_state_ids_for_event(
|
||||||
batch.events[-1].event_id, state_filter=state_filter
|
batch.events[-1].event_id, state_filter=state_filter
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
state_ids = await self.state_storage.get_state_ids_for_event(
|
state_ids = (
|
||||||
|
await self._state_storage_controller.get_state_ids_for_event(
|
||||||
batch.events[0].event_id, state_filter=state_filter
|
batch.events[0].event_id, state_filter=state_filter
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
current_state_ids = await self.get_state_at(
|
current_state_ids = await self.get_state_at(
|
||||||
|
@ -915,7 +917,7 @@ class SyncHandler:
|
||||||
elif batch.limited:
|
elif batch.limited:
|
||||||
if batch:
|
if batch:
|
||||||
state_at_timeline_start = (
|
state_at_timeline_start = (
|
||||||
await self.state_storage.get_state_ids_for_event(
|
await self._state_storage_controller.get_state_ids_for_event(
|
||||||
batch.events[0].event_id, state_filter=state_filter
|
batch.events[0].event_id, state_filter=state_filter
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -950,7 +952,7 @@ class SyncHandler:
|
||||||
|
|
||||||
if batch:
|
if batch:
|
||||||
current_state_ids = (
|
current_state_ids = (
|
||||||
await self.state_storage.get_state_ids_for_event(
|
await self._state_storage_controller.get_state_ids_for_event(
|
||||||
batch.events[-1].event_id, state_filter=state_filter
|
batch.events[-1].event_id, state_filter=state_filter
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -982,7 +984,7 @@ class SyncHandler:
|
||||||
# So we fish out all the member events corresponding to the
|
# So we fish out all the member events corresponding to the
|
||||||
# timeline here, and then dedupe any redundant ones below.
|
# timeline here, and then dedupe any redundant ones below.
|
||||||
|
|
||||||
state_ids = await self.state_storage.get_state_ids_for_event(
|
state_ids = await self._state_storage_controller.get_state_ids_for_event(
|
||||||
batch.events[0].event_id,
|
batch.events[0].event_id,
|
||||||
# we only want members!
|
# we only want members!
|
||||||
state_filter=StateFilter.from_types(
|
state_filter=StateFilter.from_types(
|
||||||
|
|
|
@ -221,7 +221,7 @@ class Notifier:
|
||||||
self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {}
|
self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {}
|
||||||
|
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.event_sources = hs.get_event_sources()
|
self.event_sources = hs.get_event_sources()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.pending_new_room_events: List[_PendingRoomEventEntry] = []
|
self.pending_new_room_events: List[_PendingRoomEventEntry] = []
|
||||||
|
@ -623,7 +623,7 @@ class Notifier:
|
||||||
|
|
||||||
if name == "room":
|
if name == "room":
|
||||||
new_events = await filter_events_for_client(
|
new_events = await filter_events_for_client(
|
||||||
self.storage,
|
self._storage_controllers,
|
||||||
user.to_string(),
|
user.to_string(),
|
||||||
new_events,
|
new_events,
|
||||||
is_peeking=is_peeking,
|
is_peeking=is_peeking,
|
||||||
|
|
|
@ -65,7 +65,7 @@ class HttpPusher(Pusher):
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
|
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
|
||||||
super().__init__(hs, pusher_config)
|
super().__init__(hs, pusher_config)
|
||||||
self.storage = self.hs.get_storage()
|
self._storage_controllers = self.hs.get_storage_controllers()
|
||||||
self.app_display_name = pusher_config.app_display_name
|
self.app_display_name = pusher_config.app_display_name
|
||||||
self.device_display_name = pusher_config.device_display_name
|
self.device_display_name = pusher_config.device_display_name
|
||||||
self.pushkey_ts = pusher_config.ts
|
self.pushkey_ts = pusher_config.ts
|
||||||
|
@ -343,7 +343,9 @@ class HttpPusher(Pusher):
|
||||||
}
|
}
|
||||||
return d
|
return d
|
||||||
|
|
||||||
ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id)
|
ctx = await push_tools.get_context_for_event(
|
||||||
|
self._storage_controllers, event, self.user_id
|
||||||
|
)
|
||||||
|
|
||||||
d = {
|
d = {
|
||||||
"notification": {
|
"notification": {
|
||||||
|
|
|
@ -114,10 +114,10 @@ class Mailer:
|
||||||
|
|
||||||
self.send_email_handler = hs.get_send_email_handler()
|
self.send_email_handler = hs.get_send_email_handler()
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
self.state_storage = self.hs.get_storage().state
|
self._state_storage_controller = self.hs.get_storage_controllers().state
|
||||||
self.macaroon_gen = self.hs.get_macaroon_generator()
|
self.macaroon_gen = self.hs.get_macaroon_generator()
|
||||||
self.state_handler = self.hs.get_state_handler()
|
self.state_handler = self.hs.get_state_handler()
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.app_name = app_name
|
self.app_name = app_name
|
||||||
self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects
|
self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects
|
||||||
|
|
||||||
|
@ -456,7 +456,7 @@ class Mailer:
|
||||||
}
|
}
|
||||||
|
|
||||||
the_events = await filter_events_for_client(
|
the_events = await filter_events_for_client(
|
||||||
self.storage, user_id, results.events_before
|
self._storage_controllers, user_id, results.events_before
|
||||||
)
|
)
|
||||||
the_events.append(notif_event)
|
the_events.append(notif_event)
|
||||||
|
|
||||||
|
@ -494,7 +494,7 @@ class Mailer:
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Attempt to check the historical state for the room.
|
# Attempt to check the historical state for the room.
|
||||||
historical_state = await self.state_storage.get_state_for_event(
|
historical_state = await self._state_storage_controller.get_state_for_event(
|
||||||
event.event_id, StateFilter.from_types((type_state_key,))
|
event.event_id, StateFilter.from_types((type_state_key,))
|
||||||
)
|
)
|
||||||
sender_state_event = historical_state.get(type_state_key)
|
sender_state_event = historical_state.get(type_state_key)
|
||||||
|
@ -767,9 +767,11 @@ class Mailer:
|
||||||
member_event_ids.append(sender_state_event_id)
|
member_event_ids.append(sender_state_event_id)
|
||||||
else:
|
else:
|
||||||
# Attempt to check the historical state for the room.
|
# Attempt to check the historical state for the room.
|
||||||
historical_state = await self.state_storage.get_state_for_event(
|
historical_state = (
|
||||||
|
await self._state_storage_controller.get_state_for_event(
|
||||||
event_id, StateFilter.from_types((type_state_key,))
|
event_id, StateFilter.from_types((type_state_key,))
|
||||||
)
|
)
|
||||||
|
)
|
||||||
sender_state_event = historical_state.get(type_state_key)
|
sender_state_event = historical_state.get(type_state_key)
|
||||||
if sender_state_event:
|
if sender_state_event:
|
||||||
member_events[event_id] = sender_state_event
|
member_events[event_id] = sender_state_event
|
||||||
|
|
|
@ -16,7 +16,7 @@ from typing import Dict
|
||||||
from synapse.api.constants import ReceiptTypes
|
from synapse.api.constants import ReceiptTypes
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
|
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
|
||||||
from synapse.storage import Storage
|
from synapse.storage.controllers import StorageControllers
|
||||||
from synapse.storage.databases.main import DataStore
|
from synapse.storage.databases.main import DataStore
|
||||||
|
|
||||||
|
|
||||||
|
@ -52,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
|
||||||
|
|
||||||
|
|
||||||
async def get_context_for_event(
|
async def get_context_for_event(
|
||||||
storage: Storage, ev: EventBase, user_id: str
|
storage: StorageControllers, ev: EventBase, user_id: str
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
ctx = {}
|
ctx = {}
|
||||||
|
|
||||||
|
|
|
@ -69,7 +69,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.federation_event_handler = hs.get_federation_event_handler()
|
self.federation_event_handler = hs.get_federation_event_handler()
|
||||||
|
|
||||||
|
@ -133,7 +133,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||||
event.internal_metadata.outlier = event_payload["outlier"]
|
event.internal_metadata.outlier = event_payload["outlier"]
|
||||||
|
|
||||||
context = EventContext.deserialize(
|
context = EventContext.deserialize(
|
||||||
self.storage, event_payload["context"]
|
self._storage_controllers, event_payload["context"]
|
||||||
)
|
)
|
||||||
|
|
||||||
event_and_contexts.append((event, context))
|
event_and_contexts.append((event, context))
|
||||||
|
|
|
@ -70,7 +70,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
||||||
|
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -127,7 +127,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
||||||
event.internal_metadata.outlier = content["outlier"]
|
event.internal_metadata.outlier = content["outlier"]
|
||||||
|
|
||||||
requester = Requester.deserialize(self.store, content["requester"])
|
requester = Requester.deserialize(self.store, content["requester"])
|
||||||
context = EventContext.deserialize(self.storage, content["context"])
|
context = EventContext.deserialize(
|
||||||
|
self._storage_controllers, content["context"]
|
||||||
|
)
|
||||||
|
|
||||||
ratelimit = content["ratelimit"]
|
ratelimit = content["ratelimit"]
|
||||||
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
|
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
|
||||||
|
|
|
@ -123,7 +123,8 @@ from synapse.server_notices.worker_server_notices_sender import (
|
||||||
WorkerServerNoticesSender,
|
WorkerServerNoticesSender,
|
||||||
)
|
)
|
||||||
from synapse.state import StateHandler, StateResolutionHandler
|
from synapse.state import StateHandler, StateResolutionHandler
|
||||||
from synapse.storage import Databases, Storage
|
from synapse.storage import Databases
|
||||||
|
from synapse.storage.controllers import StorageControllers
|
||||||
from synapse.streams.events import EventSources
|
from synapse.streams.events import EventSources
|
||||||
from synapse.types import DomainSpecificString, ISynapseReactor
|
from synapse.types import DomainSpecificString, ISynapseReactor
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
@ -729,8 +730,8 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
return PasswordPolicyHandler(self)
|
return PasswordPolicyHandler(self)
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_storage(self) -> Storage:
|
def get_storage_controllers(self) -> StorageControllers:
|
||||||
return Storage(self, self.get_datastores())
|
return StorageControllers(self, self.get_datastores())
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_replication_streamer(self) -> ReplicationStreamer:
|
def get_replication_streamer(self) -> ReplicationStreamer:
|
||||||
|
|
|
@ -127,10 +127,10 @@ class StateHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.state_storage = hs.get_storage().state
|
self._state_storage_controller = hs.get_storage_controllers().state
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||||
self._storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
async def get_current_state(
|
async def get_current_state(
|
||||||
|
@ -337,13 +337,15 @@ class StateHandler:
|
||||||
#
|
#
|
||||||
|
|
||||||
if not state_group_before_event:
|
if not state_group_before_event:
|
||||||
state_group_before_event = await self.state_storage.store_state_group(
|
state_group_before_event = (
|
||||||
|
await self._state_storage_controller.store_state_group(
|
||||||
event.event_id,
|
event.event_id,
|
||||||
event.room_id,
|
event.room_id,
|
||||||
prev_group=state_group_before_event_prev_group,
|
prev_group=state_group_before_event_prev_group,
|
||||||
delta_ids=deltas_to_state_group_before_event,
|
delta_ids=deltas_to_state_group_before_event,
|
||||||
current_state_ids=state_ids_before_event,
|
current_state_ids=state_ids_before_event,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Assign the new state group to the cached state entry.
|
# Assign the new state group to the cached state entry.
|
||||||
#
|
#
|
||||||
|
@ -359,7 +361,7 @@ class StateHandler:
|
||||||
|
|
||||||
if not event.is_state():
|
if not event.is_state():
|
||||||
return EventContext.with_state(
|
return EventContext.with_state(
|
||||||
storage=self._storage,
|
storage=self._storage_controllers,
|
||||||
state_group_before_event=state_group_before_event,
|
state_group_before_event=state_group_before_event,
|
||||||
state_group=state_group_before_event,
|
state_group=state_group_before_event,
|
||||||
state_delta_due_to_event={},
|
state_delta_due_to_event={},
|
||||||
|
@ -382,16 +384,18 @@ class StateHandler:
|
||||||
state_ids_after_event[key] = event.event_id
|
state_ids_after_event[key] = event.event_id
|
||||||
delta_ids = {key: event.event_id}
|
delta_ids = {key: event.event_id}
|
||||||
|
|
||||||
state_group_after_event = await self.state_storage.store_state_group(
|
state_group_after_event = (
|
||||||
|
await self._state_storage_controller.store_state_group(
|
||||||
event.event_id,
|
event.event_id,
|
||||||
event.room_id,
|
event.room_id,
|
||||||
prev_group=state_group_before_event,
|
prev_group=state_group_before_event,
|
||||||
delta_ids=delta_ids,
|
delta_ids=delta_ids,
|
||||||
current_state_ids=state_ids_after_event,
|
current_state_ids=state_ids_after_event,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return EventContext.with_state(
|
return EventContext.with_state(
|
||||||
storage=self._storage,
|
storage=self._storage_controllers,
|
||||||
state_group=state_group_after_event,
|
state_group=state_group_after_event,
|
||||||
state_group_before_event=state_group_before_event,
|
state_group_before_event=state_group_before_event,
|
||||||
state_delta_due_to_event=delta_ids,
|
state_delta_due_to_event=delta_ids,
|
||||||
|
@ -416,7 +420,9 @@ class StateHandler:
|
||||||
"""
|
"""
|
||||||
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
||||||
|
|
||||||
state_groups = await self.state_storage.get_state_group_for_events(event_ids)
|
state_groups = await self._state_storage_controller.get_state_group_for_events(
|
||||||
|
event_ids
|
||||||
|
)
|
||||||
|
|
||||||
state_group_ids = state_groups.values()
|
state_group_ids = state_groups.values()
|
||||||
|
|
||||||
|
@ -424,8 +430,13 @@ class StateHandler:
|
||||||
state_group_ids_set = set(state_group_ids)
|
state_group_ids_set = set(state_group_ids)
|
||||||
if len(state_group_ids_set) == 1:
|
if len(state_group_ids_set) == 1:
|
||||||
(state_group_id,) = state_group_ids_set
|
(state_group_id,) = state_group_ids_set
|
||||||
state = await self.state_storage.get_state_for_groups(state_group_ids_set)
|
state = await self._state_storage_controller.get_state_for_groups(
|
||||||
prev_group, delta_ids = await self.state_storage.get_state_group_delta(
|
state_group_ids_set
|
||||||
|
)
|
||||||
|
(
|
||||||
|
prev_group,
|
||||||
|
delta_ids,
|
||||||
|
) = await self._state_storage_controller.get_state_group_delta(
|
||||||
state_group_id
|
state_group_id
|
||||||
)
|
)
|
||||||
return _StateCacheEntry(
|
return _StateCacheEntry(
|
||||||
|
@ -439,7 +450,7 @@ class StateHandler:
|
||||||
|
|
||||||
room_version = await self.store.get_room_version_id(room_id)
|
room_version = await self.store.get_room_version_id(room_id)
|
||||||
|
|
||||||
state_to_resolve = await self.state_storage.get_state_for_groups(
|
state_to_resolve = await self._state_storage_controller.get_state_for_groups(
|
||||||
state_group_ids_set
|
state_group_ids_set
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -18,41 +18,20 @@ The storage layer is split up into multiple parts to allow Synapse to run
|
||||||
against different configurations of databases (e.g. single or multiple
|
against different configurations of databases (e.g. single or multiple
|
||||||
databases). The `DatabasePool` class represents connections to a single physical
|
databases). The `DatabasePool` class represents connections to a single physical
|
||||||
database. The `databases` are classes that talk directly to a `DatabasePool`
|
database. The `databases` are classes that talk directly to a `DatabasePool`
|
||||||
instance and have associated schemas, background updates, etc. On top of those
|
instance and have associated schemas, background updates, etc.
|
||||||
there are classes that provide high level interfaces that combine calls to
|
|
||||||
multiple `databases`.
|
On top of the databases are the StorageControllers, located in the
|
||||||
|
`synapse.storage.controllers` module. These classes provide high level
|
||||||
|
interfaces that combine calls to multiple `databases`. They are bundled into the
|
||||||
|
`StorageControllers` singleton for ease of use, and exposed via
|
||||||
|
`HomeServer.get_storage_controllers()`.
|
||||||
|
|
||||||
There are also schemas that get applied to every database, regardless of the
|
There are also schemas that get applied to every database, regardless of the
|
||||||
data stores associated with them (e.g. the schema version tables), which are
|
data stores associated with them (e.g. the schema version tables), which are
|
||||||
stored in `synapse.storage.schema`.
|
stored in `synapse.storage.schema`.
|
||||||
"""
|
"""
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from synapse.storage.databases import Databases
|
from synapse.storage.databases import Databases
|
||||||
from synapse.storage.databases.main import DataStore
|
from synapse.storage.databases.main import DataStore
|
||||||
from synapse.storage.persist_events import EventsPersistenceStorage
|
|
||||||
from synapse.storage.purge_events import PurgeEventsStorage
|
|
||||||
from synapse.storage.state import StateGroupStorage
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from synapse.server import HomeServer
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Databases", "DataStore"]
|
__all__ = ["Databases", "DataStore"]
|
||||||
|
|
||||||
|
|
||||||
class Storage:
|
|
||||||
"""The high level interfaces for talking to various storage layers."""
|
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer", stores: Databases):
|
|
||||||
# We include the main data store here mainly so that we don't have to
|
|
||||||
# rewrite all the existing code to split it into high vs low level
|
|
||||||
# interfaces.
|
|
||||||
self.main = stores.main
|
|
||||||
|
|
||||||
self.purge_events = PurgeEventsStorage(hs, stores)
|
|
||||||
self.state = StateGroupStorage(hs, stores)
|
|
||||||
|
|
||||||
self.persistence = None
|
|
||||||
if stores.persist_events:
|
|
||||||
self.persistence = EventsPersistenceStorage(hs, stores)
|
|
||||||
|
|
46
synapse/storage/controllers/__init__.py
Normal file
46
synapse/storage/controllers/__init__.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from synapse.storage.controllers.persist_events import (
|
||||||
|
EventsPersistenceStorageController,
|
||||||
|
)
|
||||||
|
from synapse.storage.controllers.purge_events import PurgeEventsStorageController
|
||||||
|
from synapse.storage.controllers.state import StateGroupStorageController
|
||||||
|
from synapse.storage.databases import Databases
|
||||||
|
from synapse.storage.databases.main import DataStore
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Databases", "DataStore"]
|
||||||
|
|
||||||
|
|
||||||
|
class StorageControllers:
|
||||||
|
"""The high level interfaces for talking to various storage controller layers."""
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer", stores: Databases):
|
||||||
|
# We include the main data store here mainly so that we don't have to
|
||||||
|
# rewrite all the existing code to split it into high vs low level
|
||||||
|
# interfaces.
|
||||||
|
self.main = stores.main
|
||||||
|
|
||||||
|
self.purge_events = PurgeEventsStorageController(hs, stores)
|
||||||
|
self.state = StateGroupStorageController(hs, stores)
|
||||||
|
|
||||||
|
self.persistence = None
|
||||||
|
if stores.persist_events:
|
||||||
|
self.persistence = EventsPersistenceStorageController(hs, stores)
|
|
@ -272,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EventsPersistenceStorage:
|
class EventsPersistenceStorageController:
|
||||||
"""High level interface for handling persisting newly received events.
|
"""High level interface for handling persisting newly received events.
|
||||||
|
|
||||||
Takes care of batching up events by room, and calculating the necessary
|
Takes care of batching up events by room, and calculating the necessary
|
|
@ -24,7 +24,7 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PurgeEventsStorage:
|
class PurgeEventsStorageController:
|
||||||
"""High level interface for purging rooms and event history."""
|
"""High level interface for purging rooms and event history."""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer", stores: Databases):
|
def __init__(self, hs: "HomeServer", stores: Databases):
|
351
synapse/storage/controllers/state.py
Normal file
351
synapse/storage/controllers/state.py
Normal file
|
@ -0,0 +1,351 @@
|
||||||
|
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Awaitable,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
|
||||||
|
from synapse.events import EventBase
|
||||||
|
from synapse.storage.state import StateFilter
|
||||||
|
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
|
||||||
|
from synapse.types import MutableStateMap, StateMap
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.databases import Databases
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StateGroupStorageController:
|
||||||
|
"""High level interface to fetching state for event."""
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer", stores: "Databases"):
|
||||||
|
self._is_mine_id = hs.is_mine_id
|
||||||
|
self.stores = stores
|
||||||
|
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
|
||||||
|
|
||||||
|
def notify_event_un_partial_stated(self, event_id: str) -> None:
|
||||||
|
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
|
||||||
|
|
||||||
|
async def get_state_group_delta(
|
||||||
|
self, state_group: int
|
||||||
|
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
|
||||||
|
"""Given a state group try to return a previous group and a delta between
|
||||||
|
the old and the new.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_group: The state group used to retrieve state deltas.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of the previous group and a state map of the event IDs which
|
||||||
|
make up the delta between the old and new state groups.
|
||||||
|
"""
|
||||||
|
|
||||||
|
state_group_delta = await self.stores.state.get_state_group_delta(state_group)
|
||||||
|
return state_group_delta.prev_group, state_group_delta.delta_ids
|
||||||
|
|
||||||
|
async def get_state_groups_ids(
|
||||||
|
self, _room_id: str, event_ids: Collection[str]
|
||||||
|
) -> Dict[int, MutableStateMap[str]]:
|
||||||
|
"""Get the event IDs of all the state for the state groups for the given events
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_room_id: id of the room for these events
|
||||||
|
event_ids: ids of the events
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError if we don't have a state group for one or more of the events
|
||||||
|
(ie they are outliers or unknown)
|
||||||
|
"""
|
||||||
|
if not event_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
event_to_groups = await self.get_state_group_for_events(event_ids)
|
||||||
|
|
||||||
|
groups = set(event_to_groups.values())
|
||||||
|
group_to_state = await self.stores.state._get_state_for_groups(groups)
|
||||||
|
|
||||||
|
return group_to_state
|
||||||
|
|
||||||
|
async def get_state_ids_for_group(
|
||||||
|
self, state_group: int, state_filter: Optional[StateFilter] = None
|
||||||
|
) -> StateMap[str]:
|
||||||
|
"""Get the event IDs of all the state in the given state group
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_group: A state group for which we want to get the state IDs.
|
||||||
|
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Resolves to a map of (type, state_key) -> event_id
|
||||||
|
"""
|
||||||
|
group_to_state = await self.get_state_for_groups((state_group,), state_filter)
|
||||||
|
|
||||||
|
return group_to_state[state_group]
|
||||||
|
|
||||||
|
async def get_state_groups(
|
||||||
|
self, room_id: str, event_ids: Collection[str]
|
||||||
|
) -> Dict[int, List[EventBase]]:
|
||||||
|
"""Get the state groups for the given list of event_ids
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: ID of the room for these events.
|
||||||
|
event_ids: The event IDs to retrieve state for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict of state_group_id -> list of state events.
|
||||||
|
"""
|
||||||
|
if not event_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
|
||||||
|
|
||||||
|
state_event_map = await self.stores.main.get_events(
|
||||||
|
[
|
||||||
|
ev_id
|
||||||
|
for group_ids in group_to_ids.values()
|
||||||
|
for ev_id in group_ids.values()
|
||||||
|
],
|
||||||
|
get_prev_content=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
group: [
|
||||||
|
state_event_map[v]
|
||||||
|
for v in event_id_map.values()
|
||||||
|
if v in state_event_map
|
||||||
|
]
|
||||||
|
for group, event_id_map in group_to_ids.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_state_groups_from_groups(
|
||||||
|
self, groups: List[int], state_filter: StateFilter
|
||||||
|
) -> Awaitable[Dict[int, StateMap[str]]]:
|
||||||
|
"""Returns the state groups for a given set of groups, filtering on
|
||||||
|
types of state events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
groups: list of state group IDs to query
|
||||||
|
state_filter: The state filter used to fetch state
|
||||||
|
from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict of state group to state map.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
|
||||||
|
|
||||||
|
async def get_state_for_events(
|
||||||
|
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
|
||||||
|
) -> Dict[str, StateMap[EventBase]]:
|
||||||
|
"""Given a list of event_ids and type tuples, return a list of state
|
||||||
|
dicts for each event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_ids: The events to fetch the state of.
|
||||||
|
state_filter: The state filter used to fetch state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict of (event_id) -> (type, state_key) -> [state_events]
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError if we don't have a state group for one or more of the events
|
||||||
|
(ie they are outliers or unknown)
|
||||||
|
"""
|
||||||
|
await_full_state = True
|
||||||
|
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
|
||||||
|
await_full_state = False
|
||||||
|
|
||||||
|
event_to_groups = await self.get_state_group_for_events(
|
||||||
|
event_ids, await_full_state=await_full_state
|
||||||
|
)
|
||||||
|
|
||||||
|
groups = set(event_to_groups.values())
|
||||||
|
group_to_state = await self.stores.state._get_state_for_groups(
|
||||||
|
groups, state_filter or StateFilter.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
state_event_map = await self.stores.main.get_events(
|
||||||
|
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
|
||||||
|
get_prev_content=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
event_to_state = {
|
||||||
|
event_id: {
|
||||||
|
k: state_event_map[v]
|
||||||
|
for k, v in group_to_state[group].items()
|
||||||
|
if v in state_event_map
|
||||||
|
}
|
||||||
|
for event_id, group in event_to_groups.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return {event: event_to_state[event] for event in event_ids}
|
||||||
|
|
||||||
|
async def get_state_ids_for_events(
|
||||||
|
self,
|
||||||
|
event_ids: Collection[str],
|
||||||
|
state_filter: Optional[StateFilter] = None,
|
||||||
|
) -> Dict[str, StateMap[str]]:
|
||||||
|
"""
|
||||||
|
Get the state dicts corresponding to a list of events, containing the event_ids
|
||||||
|
of the state events (as opposed to the events themselves)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_ids: events whose state should be returned
|
||||||
|
state_filter: The state filter used to fetch state from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict from event_id -> (type, state_key) -> event_id
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError if we don't have a state group for one or more of the events
|
||||||
|
(ie they are outliers or unknown)
|
||||||
|
"""
|
||||||
|
await_full_state = True
|
||||||
|
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
|
||||||
|
await_full_state = False
|
||||||
|
|
||||||
|
event_to_groups = await self.get_state_group_for_events(
|
||||||
|
event_ids, await_full_state=await_full_state
|
||||||
|
)
|
||||||
|
|
||||||
|
groups = set(event_to_groups.values())
|
||||||
|
group_to_state = await self.stores.state._get_state_for_groups(
|
||||||
|
groups, state_filter or StateFilter.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
event_to_state = {
|
||||||
|
event_id: group_to_state[group]
|
||||||
|
for event_id, group in event_to_groups.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return {event: event_to_state[event] for event in event_ids}
|
||||||
|
|
||||||
|
async def get_state_for_event(
|
||||||
|
self, event_id: str, state_filter: Optional[StateFilter] = None
|
||||||
|
) -> StateMap[EventBase]:
|
||||||
|
"""
|
||||||
|
Get the state dict corresponding to a particular event
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id: event whose state should be returned
|
||||||
|
state_filter: The state filter used to fetch state from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict from (type, state_key) -> state_event
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError if we don't have a state group for the event (ie it is an
|
||||||
|
outlier or is unknown)
|
||||||
|
"""
|
||||||
|
state_map = await self.get_state_for_events(
|
||||||
|
[event_id], state_filter or StateFilter.all()
|
||||||
|
)
|
||||||
|
return state_map[event_id]
|
||||||
|
|
||||||
|
async def get_state_ids_for_event(
|
||||||
|
self, event_id: str, state_filter: Optional[StateFilter] = None
|
||||||
|
) -> StateMap[str]:
|
||||||
|
"""
|
||||||
|
Get the state dict corresponding to a particular event
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id: event whose state should be returned
|
||||||
|
state_filter: The state filter used to fetch state from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict from (type, state_key) -> state_event_id
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError if we don't have a state group for the event (ie it is an
|
||||||
|
outlier or is unknown)
|
||||||
|
"""
|
||||||
|
state_map = await self.get_state_ids_for_events(
|
||||||
|
[event_id], state_filter or StateFilter.all()
|
||||||
|
)
|
||||||
|
return state_map[event_id]
|
||||||
|
|
||||||
|
def get_state_for_groups(
|
||||||
|
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
||||||
|
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
|
||||||
|
"""Gets the state at each of a list of state groups, optionally
|
||||||
|
filtering by type/state_key
|
||||||
|
|
||||||
|
Args:
|
||||||
|
groups: list of state groups for which we want to get the state.
|
||||||
|
state_filter: The state filter used to fetch state.
|
||||||
|
from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict of state group to state map.
|
||||||
|
"""
|
||||||
|
return self.stores.state._get_state_for_groups(
|
||||||
|
groups, state_filter or StateFilter.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_state_group_for_events(
|
||||||
|
self,
|
||||||
|
event_ids: Collection[str],
|
||||||
|
await_full_state: bool = True,
|
||||||
|
) -> Mapping[str, int]:
|
||||||
|
"""Returns mapping event_id -> state_group
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_ids: events to get state groups for
|
||||||
|
await_full_state: if true, will block if we do not yet have complete
|
||||||
|
state at these events.
|
||||||
|
"""
|
||||||
|
if await_full_state:
|
||||||
|
await self._partial_state_events_tracker.await_full_state(event_ids)
|
||||||
|
|
||||||
|
return await self.stores.main._get_state_group_for_events(event_ids)
|
||||||
|
|
||||||
|
async def store_state_group(
|
||||||
|
self,
|
||||||
|
event_id: str,
|
||||||
|
room_id: str,
|
||||||
|
prev_group: Optional[int],
|
||||||
|
delta_ids: Optional[StateMap[str]],
|
||||||
|
current_state_ids: StateMap[str],
|
||||||
|
) -> int:
|
||||||
|
"""Store a new set of state, returning a newly assigned state group.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id: The event ID for which the state was calculated.
|
||||||
|
room_id: ID of the room for which the state was calculated.
|
||||||
|
prev_group: A previous state group for the room, optional.
|
||||||
|
delta_ids: The delta between state at `prev_group` and
|
||||||
|
`current_state_ids`, if `prev_group` was given. Same format as
|
||||||
|
`current_state_ids`.
|
||||||
|
current_state_ids: The state to store. Map of (type, state_key)
|
||||||
|
to event_id.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The state group ID
|
||||||
|
"""
|
||||||
|
return await self.stores.state.store_state_group(
|
||||||
|
event_id, room_id, prev_group, delta_ids, current_state_ids
|
||||||
|
)
|
|
@ -15,7 +15,6 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Awaitable,
|
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
|
@ -32,15 +31,11 @@ import attr
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.events import EventBase
|
|
||||||
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
|
|
||||||
from synapse.types import MutableStateMap, StateKey, StateMap
|
from synapse.types import MutableStateMap, StateKey, StateMap
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
|
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
|
||||||
|
|
||||||
from synapse.server import HomeServer
|
|
||||||
from synapse.storage.databases import Databases
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -578,318 +573,3 @@ _ALL_NON_MEMBER_STATE_FILTER = StateFilter(
|
||||||
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
|
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
|
||||||
)
|
)
|
||||||
_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
|
_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
|
||||||
|
|
||||||
|
|
||||||
class StateGroupStorage:
|
|
||||||
"""High level interface to fetching state for event."""
|
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer", stores: "Databases"):
|
|
||||||
self._is_mine_id = hs.is_mine_id
|
|
||||||
self.stores = stores
|
|
||||||
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
|
|
||||||
|
|
||||||
def notify_event_un_partial_stated(self, event_id: str) -> None:
|
|
||||||
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
|
|
||||||
|
|
||||||
async def get_state_group_delta(
|
|
||||||
self, state_group: int
|
|
||||||
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
|
|
||||||
"""Given a state group try to return a previous group and a delta between
|
|
||||||
the old and the new.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state_group: The state group used to retrieve state deltas.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple of the previous group and a state map of the event IDs which
|
|
||||||
make up the delta between the old and new state groups.
|
|
||||||
"""
|
|
||||||
|
|
||||||
state_group_delta = await self.stores.state.get_state_group_delta(state_group)
|
|
||||||
return state_group_delta.prev_group, state_group_delta.delta_ids
|
|
||||||
|
|
||||||
async def get_state_groups_ids(
|
|
||||||
self, _room_id: str, event_ids: Collection[str]
|
|
||||||
) -> Dict[int, MutableStateMap[str]]:
|
|
||||||
"""Get the event IDs of all the state for the state groups for the given events
|
|
||||||
|
|
||||||
Args:
|
|
||||||
_room_id: id of the room for these events
|
|
||||||
event_ids: ids of the events
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError if we don't have a state group for one or more of the events
|
|
||||||
(ie they are outliers or unknown)
|
|
||||||
"""
|
|
||||||
if not event_ids:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
event_to_groups = await self.get_state_group_for_events(event_ids)
|
|
||||||
|
|
||||||
groups = set(event_to_groups.values())
|
|
||||||
group_to_state = await self.stores.state._get_state_for_groups(groups)
|
|
||||||
|
|
||||||
return group_to_state
|
|
||||||
|
|
||||||
async def get_state_ids_for_group(
|
|
||||||
self, state_group: int, state_filter: Optional[StateFilter] = None
|
|
||||||
) -> StateMap[str]:
|
|
||||||
"""Get the event IDs of all the state in the given state group
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state_group: A state group for which we want to get the state IDs.
|
|
||||||
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Resolves to a map of (type, state_key) -> event_id
|
|
||||||
"""
|
|
||||||
group_to_state = await self.get_state_for_groups((state_group,), state_filter)
|
|
||||||
|
|
||||||
return group_to_state[state_group]
|
|
||||||
|
|
||||||
async def get_state_groups(
|
|
||||||
self, room_id: str, event_ids: Collection[str]
|
|
||||||
) -> Dict[int, List[EventBase]]:
|
|
||||||
"""Get the state groups for the given list of event_ids
|
|
||||||
|
|
||||||
Args:
|
|
||||||
room_id: ID of the room for these events.
|
|
||||||
event_ids: The event IDs to retrieve state for.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict of state_group_id -> list of state events.
|
|
||||||
"""
|
|
||||||
if not event_ids:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
|
|
||||||
|
|
||||||
state_event_map = await self.stores.main.get_events(
|
|
||||||
[
|
|
||||||
ev_id
|
|
||||||
for group_ids in group_to_ids.values()
|
|
||||||
for ev_id in group_ids.values()
|
|
||||||
],
|
|
||||||
get_prev_content=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
group: [
|
|
||||||
state_event_map[v]
|
|
||||||
for v in event_id_map.values()
|
|
||||||
if v in state_event_map
|
|
||||||
]
|
|
||||||
for group, event_id_map in group_to_ids.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
def _get_state_groups_from_groups(
|
|
||||||
self, groups: List[int], state_filter: StateFilter
|
|
||||||
) -> Awaitable[Dict[int, StateMap[str]]]:
|
|
||||||
"""Returns the state groups for a given set of groups, filtering on
|
|
||||||
types of state events.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
groups: list of state group IDs to query
|
|
||||||
state_filter: The state filter used to fetch state
|
|
||||||
from the database.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict of state group to state map.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
|
|
||||||
|
|
||||||
async def get_state_for_events(
|
|
||||||
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
|
|
||||||
) -> Dict[str, StateMap[EventBase]]:
|
|
||||||
"""Given a list of event_ids and type tuples, return a list of state
|
|
||||||
dicts for each event.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_ids: The events to fetch the state of.
|
|
||||||
state_filter: The state filter used to fetch state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dict of (event_id) -> (type, state_key) -> [state_events]
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError if we don't have a state group for one or more of the events
|
|
||||||
(ie they are outliers or unknown)
|
|
||||||
"""
|
|
||||||
await_full_state = True
|
|
||||||
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
|
|
||||||
await_full_state = False
|
|
||||||
|
|
||||||
event_to_groups = await self.get_state_group_for_events(
|
|
||||||
event_ids, await_full_state=await_full_state
|
|
||||||
)
|
|
||||||
|
|
||||||
groups = set(event_to_groups.values())
|
|
||||||
group_to_state = await self.stores.state._get_state_for_groups(
|
|
||||||
groups, state_filter or StateFilter.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
state_event_map = await self.stores.main.get_events(
|
|
||||||
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
|
|
||||||
get_prev_content=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
event_to_state = {
|
|
||||||
event_id: {
|
|
||||||
k: state_event_map[v]
|
|
||||||
for k, v in group_to_state[group].items()
|
|
||||||
if v in state_event_map
|
|
||||||
}
|
|
||||||
for event_id, group in event_to_groups.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
return {event: event_to_state[event] for event in event_ids}
|
|
||||||
|
|
||||||
async def get_state_ids_for_events(
|
|
||||||
self,
|
|
||||||
event_ids: Collection[str],
|
|
||||||
state_filter: Optional[StateFilter] = None,
|
|
||||||
) -> Dict[str, StateMap[str]]:
|
|
||||||
"""
|
|
||||||
Get the state dicts corresponding to a list of events, containing the event_ids
|
|
||||||
of the state events (as opposed to the events themselves)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_ids: events whose state should be returned
|
|
||||||
state_filter: The state filter used to fetch state from the database.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dict from event_id -> (type, state_key) -> event_id
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError if we don't have a state group for one or more of the events
|
|
||||||
(ie they are outliers or unknown)
|
|
||||||
"""
|
|
||||||
await_full_state = True
|
|
||||||
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
|
|
||||||
await_full_state = False
|
|
||||||
|
|
||||||
event_to_groups = await self.get_state_group_for_events(
|
|
||||||
event_ids, await_full_state=await_full_state
|
|
||||||
)
|
|
||||||
|
|
||||||
groups = set(event_to_groups.values())
|
|
||||||
group_to_state = await self.stores.state._get_state_for_groups(
|
|
||||||
groups, state_filter or StateFilter.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
event_to_state = {
|
|
||||||
event_id: group_to_state[group]
|
|
||||||
for event_id, group in event_to_groups.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
return {event: event_to_state[event] for event in event_ids}
|
|
||||||
|
|
||||||
async def get_state_for_event(
|
|
||||||
self, event_id: str, state_filter: Optional[StateFilter] = None
|
|
||||||
) -> StateMap[EventBase]:
|
|
||||||
"""
|
|
||||||
Get the state dict corresponding to a particular event
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_id: event whose state should be returned
|
|
||||||
state_filter: The state filter used to fetch state from the database.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dict from (type, state_key) -> state_event
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError if we don't have a state group for the event (ie it is an
|
|
||||||
outlier or is unknown)
|
|
||||||
"""
|
|
||||||
state_map = await self.get_state_for_events(
|
|
||||||
[event_id], state_filter or StateFilter.all()
|
|
||||||
)
|
|
||||||
return state_map[event_id]
|
|
||||||
|
|
||||||
async def get_state_ids_for_event(
|
|
||||||
self, event_id: str, state_filter: Optional[StateFilter] = None
|
|
||||||
) -> StateMap[str]:
|
|
||||||
"""
|
|
||||||
Get the state dict corresponding to a particular event
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_id: event whose state should be returned
|
|
||||||
state_filter: The state filter used to fetch state from the database.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dict from (type, state_key) -> state_event_id
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError if we don't have a state group for the event (ie it is an
|
|
||||||
outlier or is unknown)
|
|
||||||
"""
|
|
||||||
state_map = await self.get_state_ids_for_events(
|
|
||||||
[event_id], state_filter or StateFilter.all()
|
|
||||||
)
|
|
||||||
return state_map[event_id]
|
|
||||||
|
|
||||||
def get_state_for_groups(
|
|
||||||
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
|
||||||
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
|
|
||||||
"""Gets the state at each of a list of state groups, optionally
|
|
||||||
filtering by type/state_key
|
|
||||||
|
|
||||||
Args:
|
|
||||||
groups: list of state groups for which we want to get the state.
|
|
||||||
state_filter: The state filter used to fetch state.
|
|
||||||
from the database.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict of state group to state map.
|
|
||||||
"""
|
|
||||||
return self.stores.state._get_state_for_groups(
|
|
||||||
groups, state_filter or StateFilter.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_state_group_for_events(
|
|
||||||
self,
|
|
||||||
event_ids: Collection[str],
|
|
||||||
await_full_state: bool = True,
|
|
||||||
) -> Mapping[str, int]:
|
|
||||||
"""Returns mapping event_id -> state_group
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_ids: events to get state groups for
|
|
||||||
await_full_state: if true, will block if we do not yet have complete
|
|
||||||
state at these events.
|
|
||||||
"""
|
|
||||||
if await_full_state:
|
|
||||||
await self._partial_state_events_tracker.await_full_state(event_ids)
|
|
||||||
|
|
||||||
return await self.stores.main._get_state_group_for_events(event_ids)
|
|
||||||
|
|
||||||
async def store_state_group(
|
|
||||||
self,
|
|
||||||
event_id: str,
|
|
||||||
room_id: str,
|
|
||||||
prev_group: Optional[int],
|
|
||||||
delta_ids: Optional[StateMap[str]],
|
|
||||||
current_state_ids: StateMap[str],
|
|
||||||
) -> int:
|
|
||||||
"""Store a new set of state, returning a newly assigned state group.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_id: The event ID for which the state was calculated.
|
|
||||||
room_id: ID of the room for which the state was calculated.
|
|
||||||
prev_group: A previous state group for the room, optional.
|
|
||||||
delta_ids: The delta between state at `prev_group` and
|
|
||||||
`current_state_ids`, if `prev_group` was given. Same format as
|
|
||||||
`current_state_ids`.
|
|
||||||
current_state_ids: The state to store. Map of (type, state_key)
|
|
||||||
to event_id.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The state group ID
|
|
||||||
"""
|
|
||||||
return await self.stores.state.store_state_group(
|
|
||||||
event_id, room_id, prev_group, delta_ids, current_state_ids
|
|
||||||
)
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ from typing_extensions import Final
|
||||||
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
|
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.utils import prune_event
|
from synapse.events.utils import prune_event
|
||||||
from synapse.storage import Storage
|
from synapse.storage.controllers import StorageControllers
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
|
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ _HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, ""
|
||||||
|
|
||||||
|
|
||||||
async def filter_events_for_client(
|
async def filter_events_for_client(
|
||||||
storage: Storage,
|
storage: StorageControllers,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
events: List[EventBase],
|
events: List[EventBase],
|
||||||
is_peeking: bool = False,
|
is_peeking: bool = False,
|
||||||
|
@ -268,7 +268,7 @@ async def filter_events_for_client(
|
||||||
|
|
||||||
|
|
||||||
async def filter_events_for_server(
|
async def filter_events_for_server(
|
||||||
storage: Storage,
|
storage: StorageControllers,
|
||||||
server_name: str,
|
server_name: str,
|
||||||
events: List[EventBase],
|
events: List[EventBase],
|
||||||
redact: bool = True,
|
redact: bool = True,
|
||||||
|
@ -360,7 +360,7 @@ async def filter_events_for_server(
|
||||||
|
|
||||||
|
|
||||||
async def _event_to_history_vis(
|
async def _event_to_history_vis(
|
||||||
storage: Storage, events: Collection[EventBase]
|
storage: StorageControllers, events: Collection[EventBase]
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
"""Get the history visibility at each of the given events
|
"""Get the history visibility at each of the given events
|
||||||
|
|
||||||
|
@ -407,7 +407,7 @@ async def _event_to_history_vis(
|
||||||
|
|
||||||
|
|
||||||
async def _event_to_memberships(
|
async def _event_to_memberships(
|
||||||
storage: Storage, events: Collection[EventBase], server_name: str
|
storage: StorageControllers, events: Collection[EventBase], server_name: str
|
||||||
) -> Dict[str, StateMap[EventBase]]:
|
) -> Dict[str, StateMap[EventBase]]:
|
||||||
"""Get the remote membership list at each of the given events
|
"""Get the remote membership list at each of the given events
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
|
|
||||||
self.user_id = self.register_user("u1", "pass")
|
self.user_id = self.register_user("u1", "pass")
|
||||||
self.user_tok = self.login("u1", "pass")
|
self.user_tok = self.login("u1", "pass")
|
||||||
|
@ -87,7 +87,7 @@ class TestEventContext(unittest.HomeserverTestCase):
|
||||||
def _check_serialize_deserialize(self, event, context):
|
def _check_serialize_deserialize(self, event, context):
|
||||||
serialized = self.get_success(context.serialize(event, self.store))
|
serialized = self.get_success(context.serialize(event, self.store))
|
||||||
|
|
||||||
d_context = EventContext.deserialize(self.storage, serialized)
|
d_context = EventContext.deserialize(self._storage_controllers, serialized)
|
||||||
|
|
||||||
self.assertEqual(context.state_group, d_context.state_group)
|
self.assertEqual(context.state_group, d_context.state_group)
|
||||||
self.assertEqual(context.rejected, d_context.rejected)
|
self.assertEqual(context.rejected, d_context.rejected)
|
||||||
|
|
|
@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
hs = self.setup_test_homeserver(federation_http_client=None)
|
hs = self.setup_test_homeserver(federation_http_client=None)
|
||||||
self.handler = hs.get_federation_handler()
|
self.handler = hs.get_federation_handler()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.state_storage = hs.get_storage().state
|
self.state_storage_controller = hs.get_storage_controllers().state
|
||||||
self._event_auth_handler = hs.get_event_auth_handler()
|
self._event_auth_handler = hs.get_event_auth_handler()
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
|
@ -338,7 +338,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
# mapping from (type, state_key) -> state_event_id
|
# mapping from (type, state_key) -> state_event_id
|
||||||
assert most_recent_prev_event_id is not None
|
assert most_recent_prev_event_id is not None
|
||||||
prev_state_map = self.get_success(
|
prev_state_map = self.get_success(
|
||||||
self.state_storage.get_state_ids_for_event(most_recent_prev_event_id)
|
self.state_storage_controller.get_state_ids_for_event(
|
||||||
|
most_recent_prev_event_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# List of state event ID's
|
# List of state event ID's
|
||||||
prev_state_ids = list(prev_state_map.values())
|
prev_state_ids = list(prev_state_map.values())
|
||||||
|
|
|
@ -70,7 +70,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||||
) -> None:
|
) -> None:
|
||||||
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
|
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
|
||||||
main_store = self.hs.get_datastores().main
|
main_store = self.hs.get_datastores().main
|
||||||
state_storage = self.hs.get_storage().state
|
state_storage_controller = self.hs.get_storage_controllers().state
|
||||||
|
|
||||||
# create the room
|
# create the room
|
||||||
user_id = self.register_user("kermit", "test")
|
user_id = self.register_user("kermit", "test")
|
||||||
|
@ -146,10 +146,11 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||||
)
|
)
|
||||||
if prev_exists_as_outlier:
|
if prev_exists_as_outlier:
|
||||||
prev_event.internal_metadata.outlier = True
|
prev_event.internal_metadata.outlier = True
|
||||||
persistence = self.hs.get_storage().persistence
|
persistence = self.hs.get_storage_controllers().persistence
|
||||||
self.get_success(
|
self.get_success(
|
||||||
persistence.persist_event(
|
persistence.persist_event(
|
||||||
prev_event, EventContext.for_outlier(self.hs.get_storage())
|
prev_event,
|
||||||
|
EventContext.for_outlier(self.hs.get_storage_controllers()),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -216,7 +217,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||||
|
|
||||||
# check that the state at that event is as expected
|
# check that the state at that event is as expected
|
||||||
state = self.get_success(
|
state = self.get_success(
|
||||||
state_storage.get_state_ids_for_event(pulled_event.event_id)
|
state_storage_controller.get_state_ids_for_event(pulled_event.event_id)
|
||||||
)
|
)
|
||||||
expected_state = {
|
expected_state = {
|
||||||
(e.type, e.state_key): e.event_id for e in state_at_prev_event
|
(e.type, e.state_key): e.event_id for e in state_at_prev_event
|
||||||
|
|
|
@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
self.handler = self.hs.get_event_creation_handler()
|
self.handler = self.hs.get_event_creation_handler()
|
||||||
self.persist_event_storage = self.hs.get_storage().persistence
|
self._persist_event_storage_controller = (
|
||||||
|
self.hs.get_storage_controllers().persistence
|
||||||
|
)
|
||||||
|
|
||||||
self.user_id = self.register_user("tester", "foobar")
|
self.user_id = self.register_user("tester", "foobar")
|
||||||
self.access_token = self.login("tester", "foobar")
|
self.access_token = self.login("tester", "foobar")
|
||||||
|
@ -65,7 +67,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.persist_event_storage.persist_event(memberEvent, memberEventContext)
|
self._persist_event_storage_controller.persist_event(
|
||||||
|
memberEvent, memberEventContext
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return memberEvent, memberEventContext
|
return memberEvent, memberEventContext
|
||||||
|
@ -129,7 +133,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertNotEqual(event1.event_id, event3.event_id)
|
self.assertNotEqual(event1.event_id, event3.event_id)
|
||||||
|
|
||||||
ret_event3, event_pos3, _ = self.get_success(
|
ret_event3, event_pos3, _ = self.get_success(
|
||||||
self.persist_event_storage.persist_event(event3, context)
|
self._persist_event_storage_controller.persist_event(event3, context)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert that the returned values match those from the initial event
|
# Assert that the returned values match those from the initial event
|
||||||
|
@ -143,7 +147,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertNotEqual(event1.event_id, event3.event_id)
|
self.assertNotEqual(event1.event_id, event3.event_id)
|
||||||
|
|
||||||
events, _ = self.get_success(
|
events, _ = self.get_success(
|
||||||
self.persist_event_storage.persist_events([(event3, context)])
|
self._persist_event_storage_controller.persist_events([(event3, context)])
|
||||||
)
|
)
|
||||||
ret_event4 = events[0]
|
ret_event4 = events[0]
|
||||||
|
|
||||||
|
@ -166,7 +170,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertNotEqual(event1.event_id, event2.event_id)
|
self.assertNotEqual(event1.event_id, event2.event_id)
|
||||||
|
|
||||||
events, _ = self.get_success(
|
events, _ = self.get_success(
|
||||||
self.persist_event_storage.persist_events(
|
self._persist_event_storage_controller.persist_events(
|
||||||
[(event1, context1), (event2, context2)]
|
[(event1, context1), (event2, context2)]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -954,7 +954,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_storage().persistence.persist_event(event, context)
|
self.hs.get_storage_controllers().persistence.persist_event(event, context)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
|
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
|
||||||
|
|
|
@ -32,7 +32,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
|
||||||
|
|
||||||
self.master_store = hs.get_datastores().main
|
self.master_store = hs.get_datastores().main
|
||||||
self.slaved_store = self.worker_hs.get_datastores().main
|
self.slaved_store = self.worker_hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
|
|
||||||
def replicate(self):
|
def replicate(self):
|
||||||
"""Tell the master side of replication that something has happened, and then
|
"""Tell the master side of replication that something has happened, and then
|
||||||
|
|
|
@ -262,7 +262,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
)
|
)
|
||||||
msg, msgctx = self.build_event()
|
msg, msgctx = self.build_event()
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
|
self._storage_controllers.persistence.persist_events(
|
||||||
|
[(j2, j2ctx), (msg, msgctx)]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.replicate()
|
self.replicate()
|
||||||
|
|
||||||
|
@ -323,12 +325,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
|
|
||||||
if backfill:
|
if backfill:
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.storage.persistence.persist_events(
|
self._storage_controllers.persistence.persist_events(
|
||||||
[(event, context)], backfilled=True
|
[(event, context)], backfilled=True
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
self.get_success(
|
||||||
|
self._storage_controllers.persistence.persist_event(event, context)
|
||||||
|
)
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor, clock, homeserver):
|
||||||
super().prepare(reactor, clock, homeserver)
|
super().prepare(reactor, clock, homeserver)
|
||||||
self.room_creator = homeserver.get_room_creation_handler()
|
self.room_creator = homeserver.get_room_creation_handler()
|
||||||
self.persist_event_storage = self.hs.get_storage().persistence
|
self.persist_event_storage_controller = (
|
||||||
|
self.hs.get_storage_controllers().persistence
|
||||||
|
)
|
||||||
|
|
||||||
# Create a test user
|
# Create a test user
|
||||||
self.ourUser = UserID.from_string(OUR_USER_ID)
|
self.ourUser = UserID.from_string(OUR_USER_ID)
|
||||||
|
@ -61,7 +63,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.persist_event_storage.persist_event(memberEvent, memberEventContext)
|
self.persist_event_storage_controller.persist_event(
|
||||||
|
memberEvent, memberEventContext
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Join the second user to the second room
|
# Join the second user to the second room
|
||||||
|
@ -76,7 +80,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.persist_event_storage.persist_event(memberEvent, memberEventContext)
|
self.persist_event_storage_controller.persist_event(
|
||||||
|
memberEvent, memberEventContext
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_return_empty_with_no_data(self):
|
def test_return_empty_with_no_data(self):
|
||||||
|
|
|
@ -2579,7 +2579,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
|
||||||
other_user_tok = self.login("user", "pass")
|
other_user_tok = self.login("user", "pass")
|
||||||
event_builder_factory = self.hs.get_event_builder_factory()
|
event_builder_factory = self.hs.get_event_builder_factory()
|
||||||
event_creation_handler = self.hs.get_event_creation_handler()
|
event_creation_handler = self.hs.get_event_creation_handler()
|
||||||
storage = self.hs.get_storage()
|
storage_controllers = self.hs.get_storage_controllers()
|
||||||
|
|
||||||
# Create two rooms, one with a local user only and one with both a local
|
# Create two rooms, one with a local user only and one with both a local
|
||||||
# and remote user.
|
# and remote user.
|
||||||
|
@ -2604,7 +2604,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
|
||||||
event_creation_handler.create_new_client_event(builder)
|
event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(storage.persistence.persist_event(event, context))
|
self.get_success(storage_controllers.persistence.persist_event(event, context))
|
||||||
|
|
||||||
# Now get rooms
|
# Now get rooms
|
||||||
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
|
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
|
||||||
|
|
|
@ -130,7 +130,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||||
We do this by setting a very long time between purge jobs.
|
We do this by setting a very long time between purge jobs.
|
||||||
"""
|
"""
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
storage = self.hs.get_storage()
|
storage_controllers = self.hs.get_storage_controllers()
|
||||||
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
|
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
|
||||||
|
|
||||||
# Send a first event, which should be filtered out at the end of the test.
|
# Send a first event, which should be filtered out at the end of the test.
|
||||||
|
@ -155,7 +155,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(2, len(events), "events retrieved from database")
|
self.assertEqual(2, len(events), "events retrieved from database")
|
||||||
filtered_events = self.get_success(
|
filtered_events = self.get_success(
|
||||||
filter_events_for_client(storage, self.user_id, events)
|
filter_events_for_client(storage_controllers, self.user_id, events)
|
||||||
)
|
)
|
||||||
|
|
||||||
# We should only get one event back.
|
# We should only get one event back.
|
||||||
|
|
|
@ -88,7 +88,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.clock = clock
|
self.clock = clock
|
||||||
self.storage = hs.get_storage()
|
self._storage_controllers = hs.get_storage_controllers()
|
||||||
|
|
||||||
self.virtual_user_id, _ = self.register_appservice_user(
|
self.virtual_user_id, _ = self.register_appservice_user(
|
||||||
"as_user_potato", self.appservice.token
|
"as_user_potato", self.appservice.token
|
||||||
|
@ -168,7 +168,9 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Fetch the state_groups
|
# Fetch the state_groups
|
||||||
state_group_map = self.get_success(
|
state_group_map = self.get_success(
|
||||||
self.storage.state.get_state_groups_ids(room_id, historical_event_ids)
|
self._storage_controllers.state.get_state_groups_ids(
|
||||||
|
room_id, historical_event_ids
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# We expect all of the historical events to be using the same state_group
|
# We expect all of the historical events to be using the same state_group
|
||||||
|
|
|
@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
||||||
# We need to persist the events to the events and state_events
|
# We need to persist the events to the events and state_events
|
||||||
# tables.
|
# tables.
|
||||||
persist_events_store._store_event_txn(
|
persist_events_store._store_event_txn(
|
||||||
txn, [(e, EventContext(self.hs.get_storage())) for e in events]
|
txn,
|
||||||
|
[(e, EventContext(self.hs.get_storage_controllers())) for e in events],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Actually call the function that calculates the auth chain stuff.
|
# Actually call the function that calculates the auth chain stuff.
|
||||||
|
|
|
@ -31,7 +31,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor, clock, homeserver):
|
||||||
self.state = self.hs.get_state_handler()
|
self.state = self.hs.get_state_handler()
|
||||||
self.persistence = self.hs.get_storage().persistence
|
self._persistence = self.hs.get_storage_controllers().persistence
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
|
|
||||||
self.register_user("user", "pass")
|
self.register_user("user", "pass")
|
||||||
|
@ -71,7 +71,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
context = self.get_success(
|
context = self.get_success(
|
||||||
self.state.compute_event_context(event, state_ids_before_event=state)
|
self.state.compute_event_context(event, state_ids_before_event=state)
|
||||||
)
|
)
|
||||||
self.get_success(self.persistence.persist_event(event, context))
|
self.get_success(self._persistence.persist_event(event, context))
|
||||||
|
|
||||||
def assert_extremities(self, expected_extremities):
|
def assert_extremities(self, expected_extremities):
|
||||||
"""Assert the current extremities for the room"""
|
"""Assert the current extremities for the room"""
|
||||||
|
@ -148,7 +148,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self.persistence.persist_event(remote_event_2, context))
|
self.get_success(self._persistence.persist_event(remote_event_2, context))
|
||||||
|
|
||||||
# Check that we haven't dropped the old extremity.
|
# Check that we haven't dropped the old extremity.
|
||||||
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
|
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
|
||||||
|
@ -353,7 +353,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def prepare(self, reactor, clock, homeserver):
|
||||||
self.state = self.hs.get_state_handler()
|
self.state = self.hs.get_state_handler()
|
||||||
self.persistence = self.hs.get_storage().persistence
|
self._persistence = self.hs.get_storage_controllers().persistence
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
|
|
||||||
def test_remote_user_rooms_cache_invalidated(self):
|
def test_remote_user_rooms_cache_invalidated(self):
|
||||||
|
@ -390,7 +390,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
context = self.get_success(self.state.compute_event_context(remote_event_1))
|
context = self.get_success(self.state.compute_event_context(remote_event_1))
|
||||||
self.get_success(self.persistence.persist_event(remote_event_1, context))
|
self.get_success(self._persistence.persist_event(remote_event_1, context))
|
||||||
|
|
||||||
# Call `get_rooms_for_user` to add the remote user to the cache
|
# Call `get_rooms_for_user` to add the remote user to the cache
|
||||||
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
|
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
|
||||||
|
@ -437,7 +437,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
context = self.get_success(self.state.compute_event_context(remote_event_1))
|
context = self.get_success(self.state.compute_event_context(remote_event_1))
|
||||||
self.get_success(self.persistence.persist_event(remote_event_1, context))
|
self.get_success(self._persistence.persist_event(remote_event_1, context))
|
||||||
|
|
||||||
# Call `get_users_in_room` to add the remote user to the cache
|
# Call `get_users_in_room` to add the remote user to the cache
|
||||||
users = self.get_success(self.store.get_users_in_room(room_id))
|
users = self.get_success(self.store.get_users_in_room(room_id))
|
||||||
|
|
|
@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase):
|
||||||
self.room_id = self.helper.create_room_as(self.user_id)
|
self.room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = self.hs.get_storage()
|
self._storage_controllers = self.hs.get_storage_controllers()
|
||||||
|
|
||||||
def test_purge_history(self):
|
def test_purge_history(self):
|
||||||
"""
|
"""
|
||||||
|
@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
|
||||||
|
|
||||||
# Purge everything before this topological token
|
# Purge everything before this topological token
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.storage.purge_events.purge_history(self.room_id, token_str, True)
|
self._storage_controllers.purge_events.purge_history(
|
||||||
|
self.room_id, token_str, True
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
|
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
|
||||||
|
@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase):
|
||||||
|
|
||||||
# Purge everything before this topological token
|
# Purge everything before this topological token
|
||||||
f = self.get_failure(
|
f = self.get_failure(
|
||||||
self.storage.purge_events.purge_history(self.room_id, event, True),
|
self._storage_controllers.purge_events.purge_history(
|
||||||
|
self.room_id, event, True
|
||||||
|
),
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
self.assertIn("greater than forward", f.value.args[0])
|
self.assertIn("greater than forward", f.value.args[0])
|
||||||
|
@ -105,7 +109,9 @@ class PurgeTests(HomeserverTestCase):
|
||||||
self.assertIsNotNone(create_event)
|
self.assertIsNotNone(create_event)
|
||||||
|
|
||||||
# Purge everything before this topological token
|
# Purge everything before this topological token
|
||||||
self.get_success(self.storage.purge_events.purge_room(self.room_id))
|
self.get_success(
|
||||||
|
self._storage_controllers.purge_events.purge_room(self.room_id)
|
||||||
|
)
|
||||||
|
|
||||||
# The events aren't found.
|
# The events aren't found.
|
||||||
self.store._invalidate_get_event_cache(create_event.event_id)
|
self.store._invalidate_get_event_cache(create_event.event_id)
|
||||||
|
|
|
@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self._storage = hs.get_storage_controllers()
|
||||||
self.event_builder_factory = hs.get_event_builder_factory()
|
self.event_builder_factory = hs.get_event_builder_factory()
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
self.get_success(self._storage.persistence.persist_event(event, context))
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
self.get_success(self._storage.persistence.persist_event(event, context))
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
self.get_success(self._storage.persistence.persist_event(event, context))
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self.storage.persistence.persist_event(event_1, context_1))
|
self.get_success(self._storage.persistence.persist_event(event_1, context_1))
|
||||||
|
|
||||||
event_2, context_2 = self.get_success(
|
event_2, context_2 = self.get_success(
|
||||||
self.event_creation_handler.create_new_client_event(
|
self.event_creation_handler.create_new_client_event(
|
||||||
|
@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.get_success(self.storage.persistence.persist_event(event_2, context_2))
|
self.get_success(self._storage.persistence.persist_event(event_2, context_2))
|
||||||
|
|
||||||
# fetch one of the redactions
|
# fetch one of the redactions
|
||||||
fetched = self.get_success(self.store.get_event(redaction_event_id1))
|
fetched = self.get_success(self.store.get_event(redaction_event_id1))
|
||||||
|
@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.storage.persistence.persist_event(redaction_event, context)
|
self._storage.persistence.persist_event(redaction_event, context)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now lets jump to the future where we have censored the redaction event
|
# Now lets jump to the future where we have censored the redaction event
|
||||||
|
|
|
@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
|
||||||
# Room events need the full datastore, for persist_event() and
|
# Room events need the full datastore, for persist_event() and
|
||||||
# get_room_state()
|
# get_room_state()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self._storage = hs.get_storage_controllers()
|
||||||
self.event_factory = hs.get_event_factory()
|
self.event_factory = hs.get_event_factory()
|
||||||
|
|
||||||
self.room = RoomID.from_string("!abcde:test")
|
self.room = RoomID.from_string("!abcde:test")
|
||||||
|
@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
def inject_room_event(self, **kwargs):
|
def inject_room_event(self, **kwargs):
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.storage.persistence.persist_event(
|
self._storage.persistence.persist_event(
|
||||||
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
|
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase):
|
||||||
prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
|
prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
|
||||||
prev_event = self.get_success(store.get_event(prev_event_ids[0]))
|
prev_event = self.get_success(store.get_event(prev_event_ids[0]))
|
||||||
prev_state_map = self.get_success(
|
prev_state_map = self.get_success(
|
||||||
self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0])
|
self.hs.get_storage_controllers().state.get_state_ids_for_event(
|
||||||
|
prev_event_ids[0]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
event_dict = {
|
event_dict = {
|
||||||
|
|
|
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
||||||
class StateStoreTestCase(HomeserverTestCase):
|
class StateStoreTestCase(HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.storage = hs.get_storage()
|
self.storage = hs.get_storage_controllers()
|
||||||
self.state_datastore = self.storage.state.stores.state
|
self.state_datastore = self.storage.state.stores.state
|
||||||
self.event_builder_factory = hs.get_event_builder_factory()
|
self.event_builder_factory = hs.get_event_builder_factory()
|
||||||
self.event_creation_handler = hs.get_event_creation_handler()
|
self.event_creation_handler = hs.get_event_creation_handler()
|
||||||
|
|
|
@ -179,12 +179,12 @@ class Graph:
|
||||||
class StateTestCase(unittest.TestCase):
|
class StateTestCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.dummy_store = _DummyStore()
|
self.dummy_store = _DummyStore()
|
||||||
storage = Mock(main=self.dummy_store, state=self.dummy_store)
|
storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
|
||||||
hs = Mock(
|
hs = Mock(
|
||||||
spec_set=[
|
spec_set=[
|
||||||
"config",
|
"config",
|
||||||
"get_datastores",
|
"get_datastores",
|
||||||
"get_storage",
|
"get_storage_controllers",
|
||||||
"get_auth",
|
"get_auth",
|
||||||
"get_state_handler",
|
"get_state_handler",
|
||||||
"get_clock",
|
"get_clock",
|
||||||
|
@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
hs.get_clock.return_value = MockClock()
|
hs.get_clock.return_value = MockClock()
|
||||||
hs.get_auth.return_value = Auth(hs)
|
hs.get_auth.return_value = Auth(hs)
|
||||||
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
|
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
|
||||||
hs.get_storage.return_value = storage
|
hs.get_storage_controllers.return_value = storage_controllers
|
||||||
|
|
||||||
self.state = StateHandler(hs)
|
self.state = StateHandler(hs)
|
||||||
self.event_id = 0
|
self.event_id = 0
|
||||||
|
|
|
@ -70,7 +70,7 @@ async def inject_event(
|
||||||
"""
|
"""
|
||||||
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
|
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
|
||||||
|
|
||||||
persistence = hs.get_storage().persistence
|
persistence = hs.get_storage_controllers().persistence
|
||||||
assert persistence is not None
|
assert persistence is not None
|
||||||
|
|
||||||
await persistence.persist_event(event, context)
|
await persistence.persist_event(event, context)
|
||||||
|
|
|
@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
super(FilterEventsForServerTestCase, self).setUp()
|
super(FilterEventsForServerTestCase, self).setUp()
|
||||||
self.event_creation_handler = self.hs.get_event_creation_handler()
|
self.event_creation_handler = self.hs.get_event_creation_handler()
|
||||||
self.event_builder_factory = self.hs.get_event_builder_factory()
|
self.event_builder_factory = self.hs.get_event_builder_factory()
|
||||||
self.storage = self.hs.get_storage()
|
self._storage_controllers = self.hs.get_storage_controllers()
|
||||||
|
|
||||||
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
|
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
|
||||||
|
|
||||||
|
@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
events_to_filter.append(evt)
|
events_to_filter.append(evt)
|
||||||
|
|
||||||
filtered = self.get_success(
|
filtered = self.get_success(
|
||||||
filter_events_for_server(self.storage, "test_server", events_to_filter)
|
filter_events_for_server(
|
||||||
|
self._storage_controllers, "test_server", events_to_filter
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# the result should be 5 redacted events, and 5 unredacted events.
|
# the result should be 5 redacted events, and 5 unredacted events.
|
||||||
|
@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
outlier = self._inject_outlier()
|
outlier = self._inject_outlier()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.get_success(
|
self.get_success(
|
||||||
filter_events_for_server(self.storage, "remote_hs", [outlier])
|
filter_events_for_server(
|
||||||
|
self._storage_controllers, "remote_hs", [outlier]
|
||||||
|
)
|
||||||
),
|
),
|
||||||
[outlier],
|
[outlier],
|
||||||
)
|
)
|
||||||
|
@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
evt = self._inject_message("@unerased:local_hs")
|
evt = self._inject_message("@unerased:local_hs")
|
||||||
|
|
||||||
filtered = self.get_success(
|
filtered = self.get_success(
|
||||||
filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
|
filter_events_for_server(
|
||||||
|
self._storage_controllers, "remote_hs", [outlier, evt]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
|
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
|
||||||
self.assertEqual(filtered[0], outlier)
|
self.assertEqual(filtered[0], outlier)
|
||||||
|
@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
# ... but other servers should only be able to see the outlier (the other should
|
# ... but other servers should only be able to see the outlier (the other should
|
||||||
# be redacted)
|
# be redacted)
|
||||||
filtered = self.get_success(
|
filtered = self.get_success(
|
||||||
filter_events_for_server(self.storage, "other_server", [outlier, evt])
|
filter_events_for_server(
|
||||||
|
self._storage_controllers, "other_server", [outlier, evt]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(filtered[0], outlier)
|
self.assertEqual(filtered[0], outlier)
|
||||||
self.assertEqual(filtered[1].event_id, evt.event_id)
|
self.assertEqual(filtered[1].event_id, evt.event_id)
|
||||||
|
@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# ... and the filtering happens.
|
# ... and the filtering happens.
|
||||||
filtered = self.get_success(
|
filtered = self.get_success(
|
||||||
filter_events_for_server(self.storage, "test_server", events_to_filter)
|
filter_events_for_server(
|
||||||
|
self._storage_controllers, "test_server", events_to_filter
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in range(0, len(events_to_filter)):
|
for i in range(0, len(events_to_filter)):
|
||||||
|
@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
event, context = self.get_success(
|
event, context = self.get_success(
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
self.get_success(
|
||||||
|
self._storage_controllers.persistence.persist_event(event, context)
|
||||||
|
)
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def _inject_room_member(
|
def _inject_room_member(
|
||||||
|
@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
self.get_success(
|
||||||
|
self._storage_controllers.persistence.persist_event(event, context)
|
||||||
|
)
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def _inject_message(
|
def _inject_message(
|
||||||
|
@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
self.get_success(
|
||||||
|
self._storage_controllers.persistence.persist_event(event, context)
|
||||||
|
)
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def _inject_outlier(self) -> EventBase:
|
def _inject_outlier(self) -> EventBase:
|
||||||
|
@ -234,8 +250,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
|
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
|
||||||
event.internal_metadata.outlier = True
|
event.internal_metadata.outlier = True
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.storage.persistence.persist_event(
|
self._storage_controllers.persistence.persist_event(
|
||||||
event, EventContext.for_outlier(self.storage)
|
event, EventContext.for_outlier(self._storage_controllers)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return event
|
return event
|
||||||
|
@ -293,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.get_success(
|
self.get_success(
|
||||||
filter_events_for_client(
|
filter_events_for_client(
|
||||||
self.hs.get_storage(), "@user:test", [invite_event, reject_event]
|
self.hs.get_storage_controllers(),
|
||||||
|
"@user:test",
|
||||||
|
[invite_event, reject_event],
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
[invite_event, reject_event],
|
[invite_event, reject_event],
|
||||||
|
@ -303,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.get_success(
|
self.get_success(
|
||||||
filter_events_for_client(
|
filter_events_for_client(
|
||||||
self.hs.get_storage(), "@other:test", [invite_event, reject_event]
|
self.hs.get_storage_controllers(),
|
||||||
|
"@other:test",
|
||||||
|
[invite_event, reject_event],
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
[],
|
[],
|
||||||
|
|
|
@ -264,7 +264,7 @@ class MockClock:
|
||||||
async def create_room(hs, room_id: str, creator_id: str):
|
async def create_room(hs, room_id: str, creator_id: str):
|
||||||
"""Creates and persist a creation event for the given room"""
|
"""Creates and persist a creation event for the given room"""
|
||||||
|
|
||||||
persistence_store = hs.get_storage().persistence
|
persistence_store = hs.get_storage_controllers().persistence
|
||||||
store = hs.get_datastores().main
|
store = hs.get_datastores().main
|
||||||
event_builder_factory = hs.get_event_builder_factory()
|
event_builder_factory = hs.get_event_builder_factory()
|
||||||
event_creation_handler = hs.get_event_creation_handler()
|
event_creation_handler = hs.get_event_creation_handler()
|
||||||
|
|
Loading…
Reference in a new issue