Faster joins: persist to database (#12012)
When we get a partial_state response from send_join, store information in the database about it:

 * store a record about the room as a whole having partial state, and stash the list of member servers too.
 * flag the join event itself as having partial state
 * also, for any new events whose prev-events are partial-stated, note that they will *also* be partial-stated.

We don't yet make any attempt to interpret this data, so API calls (and a bunch of other things) are just going to get incorrect data.
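The propagation rule in the last bullet is the heart of the change. As a rough illustration only (not the Synapse implementation; the class and names below are made up for the sketch, and the real code records these flags in the partial_state_rooms, partial_state_rooms_servers and partial_state_events tables added further down):

# Illustrative sketch of the bookkeeping this commit introduces.
from typing import Dict, Iterable, Set


class PartialStateTracker:
    def __init__(self) -> None:
        self.partial_state_rooms: Set[str] = set()        # rooms joined with partial state
        self.servers_in_room: Dict[str, Set[str]] = {}    # stashed member servers per room
        self.partial_state_events: Set[str] = set()       # events stored with partial state

    def on_partial_join(self, room_id: str, join_event_id: str, servers: Iterable[str]) -> None:
        # a partial_state send_join response: record the room (and the servers we can
        # later ask for the rest of the state), and flag the join event itself
        self.partial_state_rooms.add(room_id)
        self.servers_in_room[room_id] = set(servers)
        self.partial_state_events.add(join_event_id)

    def get_partial_state_events(self, event_ids: Iterable[str]) -> Dict[str, bool]:
        return {e: e in self.partial_state_events for e in event_ids}

    def on_new_event(self, event_id: str, prev_event_ids: Iterable[str]) -> None:
        # if any prev-event has partial state, the new event does too
        if any(self.get_partial_state_events(prev_event_ids).values()):
            self.partial_state_events.add(event_id)


tracker = PartialStateTracker()
tracker.on_partial_join("!room:example.org", "$join", ["other.example.com"])
tracker.on_new_event("$next", prev_event_ids=["$join"])
assert tracker.get_partial_state_events(["$join", "$next"]) == {"$join": True, "$next": True}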
Parent: 4ccc2d09aa
Commit: e2e1d90a5e
12 changed files with 297 additions and 32 deletions
changelog.d/12012.misc (new file)
@@ -0,0 +1 @@
+Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database.

@@ -101,6 +101,9 @@ class EventContext:
         As with _current_state_ids, this is a private attribute. It should be
         accessed via get_prev_state_ids.
+
+    partial_state: if True, we may be storing this event with a temporary,
+        incomplete state.
     """

     rejected: Union[bool, str] = False
@@ -113,12 +116,15 @@ class EventContext:
     _current_state_ids: Optional[StateMap[str]] = None
     _prev_state_ids: Optional[StateMap[str]] = None

+    partial_state: bool = False
+
     @staticmethod
     def with_state(
         state_group: Optional[int],
         state_group_before_event: Optional[int],
         current_state_ids: Optional[StateMap[str]],
         prev_state_ids: Optional[StateMap[str]],
+        partial_state: bool,
         prev_group: Optional[int] = None,
         delta_ids: Optional[StateMap[str]] = None,
     ) -> "EventContext":
@@ -129,6 +135,7 @@ class EventContext:
             state_group_before_event=state_group_before_event,
             prev_group=prev_group,
             delta_ids=delta_ids,
+            partial_state=partial_state,
         )

     @staticmethod
@@ -170,6 +177,7 @@ class EventContext:
             "prev_group": self.prev_group,
             "delta_ids": _encode_state_dict(self.delta_ids),
             "app_service_id": self.app_service.id if self.app_service else None,
+            "partial_state": self.partial_state,
         }

     @staticmethod
@@ -196,6 +204,7 @@ class EventContext:
             prev_group=input["prev_group"],
             delta_ids=_decode_state_dict(input["delta_ids"]),
             rejected=input["rejected"],
+            partial_state=input.get("partial_state", False),
         )

         app_service_id = input["app_service_id"]

@@ -519,8 +519,17 @@ class FederationHandler:
                 state_events=state,
             )

+            if ret.partial_state:
+                await self.store.store_partial_state_room(room_id, ret.servers_in_room)
+
             max_stream_id = await self._federation_event_handler.process_remote_join(
-                origin, room_id, auth_chain, state, event, room_version_obj
+                origin,
+                room_id,
+                auth_chain,
+                state,
+                event,
+                room_version_obj,
+                partial_state=ret.partial_state,
             )

             # We wait here until this instance has seen the events come down

@@ -397,6 +397,7 @@ class FederationEventHandler:
         state: List[EventBase],
         event: EventBase,
         room_version: RoomVersion,
+        partial_state: bool,
     ) -> int:
         """Persists the events returned by a send_join

@@ -412,6 +413,7 @@ class FederationEventHandler:
                 event
             room_version: The room version we expect this room to have, and
                 will raise if it doesn't match the version in the create event.
+            partial_state: True if the state omits non-critical membership events

         Returns:
             The stream ID after which all events have been persisted.
@@ -453,10 +455,14 @@ class FederationEventHandler:
         )

         # and now persist the join event itself.
-        logger.info("Peristing join-via-remote %s", event)
+        logger.info(
+            "Peristing join-via-remote %s (partial_state: %s)", event, partial_state
+        )
         with nested_logging_context(suffix=event.event_id):
             context = await self._state_handler.compute_event_context(
-                event, old_state=state
+                event,
+                old_state=state,
+                partial_state=partial_state,
             )

             context = await self._check_event_auth(origin, event, context)
@@ -698,6 +704,8 @@ class FederationEventHandler:

         try:
             state = await self._resolve_state_at_missing_prevs(origin, event)
+            # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
+            #   not return partial state
            await self._process_received_pdu(
                 origin, event, state=state, backfilled=backfilled
             )
@@ -1791,6 +1799,7 @@ class FederationEventHandler:
             prev_state_ids=prev_state_ids,
             prev_group=prev_group,
             delta_ids=state_updates,
+            partial_state=context.partial_state,
         )

     async def _run_push_actions_and_persist_event(

@@ -992,6 +992,8 @@ class EventCreationHandler:
             and full_state_ids_at_event
             and builder.internal_metadata.is_historical()
         ):
+            # TODO(faster_joins): figure out how this works, and make sure that the
+            #   old state is complete.
             old_state = await self.store.get_events_as_list(full_state_ids_at_event)
             context = await self.state.compute_event_context(event, old_state=old_state)
         else:

@@ -258,7 +258,10 @@ class StateHandler:
         return await self.store.get_joined_hosts(room_id, entry)

     async def compute_event_context(
-        self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
+        self,
+        event: EventBase,
+        old_state: Optional[Iterable[EventBase]] = None,
+        partial_state: bool = False,
     ) -> EventContext:
         """Build an EventContext structure for a non-outlier event.

@@ -273,6 +276,8 @@ class StateHandler:
                 calculated from existing events. This is normally only specified
                 when receiving an event from federation where we don't have the
                 prev events for, e.g. when backfilling.
+            partial_state: True if `old_state` is partial and omits non-critical
+                membership events
         Returns:
             The event context.
         """
@@ -295,8 +300,28 @@ class StateHandler:

         else:
             # otherwise, we'll need to resolve the state across the prev_events.
-            logger.debug("calling resolve_state_groups from compute_event_context")
+
+            # partial_state should not be set explicitly in this case:
+            # we work it out dynamically
+            assert not partial_state
+
+            # if any of the prev-events have partial state, so do we.
+            # (This is slightly racy - the prev-events might get fixed up before we use
+            # their states - but I don't think that really matters; it just means we
+            # might redundantly recalculate the state for this event later.)
+            prev_event_ids = event.prev_event_ids()
+            incomplete_prev_events = await self.store.get_partial_state_events(
+                prev_event_ids
+            )
+            if any(incomplete_prev_events.values()):
+                logger.debug(
+                    "New/incoming event %s refers to prev_events %s with partial state",
+                    event.event_id,
+                    [k for (k, v) in incomplete_prev_events.items() if v],
+                )
+                partial_state = True
+
+            logger.debug("calling resolve_state_groups from compute_event_context")
             entry = await self.resolve_state_groups_for_events(
                 event.room_id, event.prev_event_ids()
             )
@@ -342,6 +367,7 @@ class StateHandler:
             prev_state_ids=state_ids_before_event,
             prev_group=state_group_before_event_prev_group,
             delta_ids=deltas_to_state_group_before_event,
+            partial_state=partial_state,
         )

         #
@@ -373,6 +399,7 @@ class StateHandler:
             prev_state_ids=state_ids_before_event,
             prev_group=state_group_before_event,
             delta_ids=delta_ids,
+            partial_state=partial_state,
         )

     @measure_func()

@@ -2145,6 +2145,14 @@ class PersistEventsStore:
         state_groups = {}
         for event, context in events_and_contexts:
             if event.internal_metadata.is_outlier():
+                # double-check that we don't have any events that claim to be outliers
+                # *and* have partial state (which is meaningless: we should have no
+                # state at all for an outlier)
+                if context.partial_state:
+                    raise ValueError(
+                        "Outlier event %s claims to have partial state", event.event_id
+                    )
+
                 continue

             # if the event was rejected, just give it the same state as its
@@ -2155,6 +2163,23 @@ class PersistEventsStore:

             state_groups[event.event_id] = context.state_group

+        # if we have partial state for these events, record the fact. (This happens
+        # here rather than in _store_event_txn because it also needs to happen when
+        # we de-outlier an event.)
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="partial_state_events",
+            keys=("room_id", "event_id"),
+            values=[
+                (
+                    event.room_id,
+                    event.event_id,
+                )
+                for event, ctx in events_and_contexts
+                if ctx.partial_state
+            ],
+        )
+
         self.db_pool.simple_upsert_many_txn(
             txn,
             table="event_to_state_groups",

@@ -1953,3 +1953,31 @@ class EventsWorkerStore(SQLBaseStore):
             "get_event_id_for_timestamp_txn",
             get_event_id_for_timestamp_txn,
         )
+
+    @cachedList("is_partial_state_event", list_name="event_ids")
+    async def get_partial_state_events(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, bool]:
+        """Checks which of the given events have partial state"""
+        result = await self.db_pool.simple_select_many_batch(
+            table="partial_state_events",
+            column="event_id",
+            iterable=event_ids,
+            retcols=["event_id"],
+            desc="get_partial_state_events",
+        )
+        # convert the result to a dict, to make @cachedList work
+        partial = {r["event_id"] for r in result}
+        return {e_id: e_id in partial for e_id in event_ids}
+
+    @cached()
+    async def is_partial_state_event(self, event_id: str) -> bool:
+        """Checks if the given event has partial state"""
+        result = await self.db_pool.simple_select_one_onecol(
+            table="partial_state_events",
+            keyvalues={"event_id": event_id},
+            retcol="1",
+            allow_none=True,
+            desc="is_partial_state_event",
+        )
+        return result is not None

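For context, the batched accessor added above is what the compute_event_context change earlier uses to decide whether a new event inherits partial state from its prev-events. A hypothetical caller (the name inherits_partial_state and the bare store parameter are illustrative, not part of the commit) looks roughly like this:

# Hypothetical usage sketch: does a new event need the partial-state flag,
# given its prev-events? `store` is assumed to expose get_partial_state_events.
from typing import Collection


async def inherits_partial_state(store, prev_event_ids: Collection[str]) -> bool:
    incomplete_prev_events = await store.get_partial_state_events(prev_event_ids)
    return any(incomplete_prev_events.values())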
@@ -20,6 +20,7 @@ from typing import (
     TYPE_CHECKING,
     Any,
     Awaitable,
+    Collection,
     Dict,
     List,
     Optional,
@@ -1543,6 +1544,42 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             lock=False,
         )

+    async def store_partial_state_room(
+        self,
+        room_id: str,
+        servers: Collection[str],
+    ) -> None:
+        """Mark the given room as containing events with partial state
+
+        Args:
+            room_id: the ID of the room
+            servers: other servers known to be in the room
+        """
+        await self.db_pool.runInteraction(
+            "store_partial_state_room",
+            self._store_partial_state_room_txn,
+            room_id,
+            servers,
+        )
+
+    @staticmethod
+    def _store_partial_state_room_txn(
+        txn: LoggingTransaction, room_id: str, servers: Collection[str]
+    ) -> None:
+        DatabasePool.simple_insert_txn(
+            txn,
+            table="partial_state_rooms",
+            values={
+                "room_id": room_id,
+            },
+        )
+        DatabasePool.simple_insert_many_txn(
+            txn,
+            table="partial_state_rooms_servers",
+            keys=("room_id", "server_name"),
+            values=((room_id, s) for s in servers),
+        )
+
     async def maybe_store_room_on_outlier_membership(
         self, room_id: str, room_version: RoomVersion
     ) -> None:

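As a concrete illustration (the room and server names are assumed, not from the commit), a call like store_partial_state_room("!room:example.org", ["other.example.com", "third.example.com"]) writes one row into partial_state_rooms and one row per server into partial_state_rooms_servers:

# Rows produced by the hypothetical call above (illustrative only).
partial_state_rooms_row = {"room_id": "!room:example.org"}
partial_state_rooms_servers_rows = [
    ("!room:example.org", "other.example.com"),
    ("!room:example.org", "third.example.com"),
]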
@@ -0,0 +1,41 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- rooms which we have done a partial-state-style join to
+CREATE TABLE IF NOT EXISTS partial_state_rooms (
+    room_id TEXT PRIMARY KEY,
+    FOREIGN KEY(room_id) REFERENCES rooms(room_id)
+);
+
+-- a list of remote servers we believe are in the room
+CREATE TABLE IF NOT EXISTS partial_state_rooms_servers (
+    room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id),
+    server_name TEXT NOT NULL,
+    UNIQUE(room_id, server_name)
+);
+
+-- a list of events with partial state. We can't store this in the `events` table
+-- itself, because `events` is meant to be append-only.
+CREATE TABLE IF NOT EXISTS partial_state_events (
+    -- the room_id is denormalised for efficient indexing (the canonical source is `events`)
+    room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id),
+    event_id TEXT NOT NULL REFERENCES events(event_id),
+    UNIQUE(event_id)
+);
+
+CREATE INDEX IF NOT EXISTS partial_state_events_room_id_idx
+    ON partial_state_events (room_id);

@@ -0,0 +1,72 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This migration adds triggers to the partial_state_events tables to enforce uniqueness
+
+Triggers cannot be expressed in .sql files, so we have to use a separate file.
+"""
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.types import Cursor
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+    # complain if the room_id in partial_state_events doesn't match
+    # that in `events`. We already have a fk constraint which ensures that the event
+    # exists in `events`, so all we have to do is raise if there is a row with a
+    # matching stream_ordering but not a matching room_id.
+    if isinstance(database_engine, Sqlite3Engine):
+        cur.execute(
+            """
+            CREATE TRIGGER IF NOT EXISTS partial_state_events_bad_room_id
+            BEFORE INSERT ON partial_state_events
+            FOR EACH ROW
+            BEGIN
+                SELECT RAISE(ABORT, 'Incorrect room_id in partial_state_events')
+                WHERE EXISTS (
+                    SELECT 1 FROM events
+                    WHERE events.event_id = NEW.event_id
+                       AND events.room_id != NEW.room_id
+                );
+            END;
+            """
+        )
+    elif isinstance(database_engine, PostgresEngine):
+        cur.execute(
+            """
+            CREATE OR REPLACE FUNCTION check_partial_state_events() RETURNS trigger AS $BODY$
+            BEGIN
+                IF EXISTS (
+                    SELECT 1 FROM events
+                    WHERE events.event_id = NEW.event_id
+                       AND events.room_id != NEW.room_id
+                ) THEN
+                    RAISE EXCEPTION 'Incorrect room_id in partial_state_events';
+                END IF;
+                RETURN NEW;
+            END;
+            $BODY$ LANGUAGE plpgsql;
+            """
+        )
+
+        cur.execute(
+            """
+            CREATE TRIGGER check_partial_state_events BEFORE INSERT OR UPDATE ON partial_state_events
+            FOR EACH ROW
+            EXECUTE PROCEDURE check_partial_state_events()
+            """
+        )
+    else:
+        raise NotImplementedError("Unknown database engine")

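To see the SQLite variant of the trigger in action, here is a self-contained check against an in-memory database. It is illustrative only: it recreates cut-down events and partial_state_events tables rather than the full Synapse schema, and the row values are made up.

# Illustrative check of the SQLite trigger, using cut-down tables.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE events (event_id TEXT PRIMARY KEY, room_id TEXT NOT NULL);
    CREATE TABLE partial_state_events (
        room_id TEXT NOT NULL,
        event_id TEXT NOT NULL,
        UNIQUE(event_id)
    );
    CREATE TRIGGER partial_state_events_bad_room_id
    BEFORE INSERT ON partial_state_events
    FOR EACH ROW
    BEGIN
        SELECT RAISE(ABORT, 'Incorrect room_id in partial_state_events')
        WHERE EXISTS (
            SELECT 1 FROM events
            WHERE events.event_id = NEW.event_id
               AND events.room_id != NEW.room_id
        );
    END;
    """
)
conn.execute("INSERT INTO events VALUES ('$join', '!room:a'), ('$next', '!room:a')")
# consistent room_id: the insert is accepted
conn.execute("INSERT INTO partial_state_events VALUES ('!room:a', '$join')")
# inconsistent room_id: the trigger aborts the insert
try:
    conn.execute("INSERT INTO partial_state_events VALUES ('!other:b', '$next')")
except sqlite3.IntegrityError as e:
    print("rejected:", e)  # Incorrect room_id in partial_state_events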
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional
+from typing import Collection, Dict, List, Optional
 from unittest.mock import Mock

 from twisted.internet import defer
@@ -70,7 +70,7 @@ def create_event(
     return event


-class StateGroupStore:
+class _DummyStore:
     def __init__(self):
         self._event_to_state_group = {}
         self._group_to_state = {}
@@ -105,6 +105,11 @@ class StateGroupStore:
             if e_id in self._event_id_to_event
         }

+    async def get_partial_state_events(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, bool]:
+        return {e: False for e in event_ids}
+
     async def get_state_group_delta(self, name):
         return None, None

@@ -157,8 +162,8 @@ class Graph:

 class StateTestCase(unittest.TestCase):
     def setUp(self):
-        self.store = StateGroupStore()
-        storage = Mock(main=self.store, state=self.store)
+        self.dummy_store = _DummyStore()
+        storage = Mock(main=self.dummy_store, state=self.dummy_store)
         hs = Mock(
             spec_set=[
                 "config",
@@ -173,7 +178,7 @@ class StateTestCase(unittest.TestCase):
             ]
         )
         hs.config = default_config("tesths", True)
-        hs.get_datastores.return_value = Mock(main=self.store)
+        hs.get_datastores.return_value = Mock(main=self.dummy_store)
         hs.get_state_handler.return_value = None
         hs.get_clock.return_value = MockClock()
         hs.get_auth.return_value = Auth(hs)
@@ -198,7 +203,7 @@ class StateTestCase(unittest.TestCase):
             edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
         )

-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())

         context_store: dict[str, EventContext] = {}

@@ -206,7 +211,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context

         ctx_c = context_store["C"]
@@ -242,7 +247,7 @@ class StateTestCase(unittest.TestCase):
             edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
         )

-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())

         context_store = {}

@@ -250,7 +255,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context

         # C ends up winning the resolution between B and C
@@ -300,7 +305,7 @@ class StateTestCase(unittest.TestCase):
             edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]},
         )

-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())

         context_store = {}

@@ -308,7 +313,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context

         # C ends up winning the resolution between C and D because bans win over other
@@ -375,7 +380,7 @@ class StateTestCase(unittest.TestCase):
         self._add_depths(nodes, edges)
         graph = Graph(nodes, edges)

-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())

         context_store = {}

@@ -383,7 +388,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context

         # B ends up winning the resolution between B and C because power levels
@@ -476,7 +481,7 @@ class StateTestCase(unittest.TestCase):
         ]

         group_name = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id,
                 event.room_id,
                 None,
@@ -484,7 +489,7 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id, group_name)
+        self.dummy_store.register_event_id_state_group(prev_event_id, group_name)

         context = yield defer.ensureDeferred(self.state.compute_event_context(event))

@@ -510,7 +515,7 @@ class StateTestCase(unittest.TestCase):
         ]

         group_name = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id,
                 event.room_id,
                 None,
@@ -518,7 +523,7 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id, group_name)
+        self.dummy_store.register_event_id_state_group(prev_event_id, group_name)

         context = yield defer.ensureDeferred(self.state.compute_event_context(event))

@@ -554,8 +559,8 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]

-        self.store.register_events(old_state_1)
-        self.store.register_events(old_state_2)
+        self.dummy_store.register_events(old_state_1)
+        self.dummy_store.register_events(old_state_2)

         context = yield self._get_context(
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
@@ -594,10 +599,10 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]

-        store = StateGroupStore()
+        store = _DummyStore()
         store.register_events(old_state_1)
         store.register_events(old_state_2)
-        self.store.get_events = store.get_events
+        self.dummy_store.get_events = store.get_events

         context = yield self._get_context(
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
@@ -649,10 +654,10 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test1", state_key="1", depth=2),
         ]

-        store = StateGroupStore()
+        store = _DummyStore()
         store.register_events(old_state_1)
         store.register_events(old_state_2)
-        self.store.get_events = store.get_events
+        self.dummy_store.get_events = store.get_events

         context = yield self._get_context(
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
@@ -695,7 +700,7 @@ class StateTestCase(unittest.TestCase):
         self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
     ):
         sg1 = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id_1,
                 event.room_id,
                 None,
@@ -703,10 +708,10 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state_1},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id_1, sg1)
+        self.dummy_store.register_event_id_state_group(prev_event_id_1, sg1)

         sg2 = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id_2,
                 event.room_id,
                 None,
@@ -714,7 +719,7 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state_2},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id_2, sg2)
+        self.dummy_store.register_event_id_state_group(prev_event_id_2, sg2)

         result = yield defer.ensureDeferred(self.state.compute_event_context(event))
         return result