forked from MirrorHub/synapse
Batch up storing state groups when creating new room (#14918)
This commit is contained in:
parent
335f52d595
commit
1c95ddd09b
14 changed files with 371 additions and 49 deletions
1
changelog.d/14918.misc
Normal file
1
changelog.d/14918.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Batch up storing state groups when creating a new room.
|
|
@ -23,6 +23,7 @@ from synapse.types import JsonDict, StateMap
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage.controllers import StorageControllers
|
||||
from synapse.storage.databases import StateGroupDataStore
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.types.state import StateFilter
|
||||
|
||||
|
@ -348,6 +349,54 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
|
|||
partial_state: bool
|
||||
state_map_before_event: Optional[StateMap[str]] = None
|
||||
|
||||
@classmethod
|
||||
async def batch_persist_unpersisted_contexts(
|
||||
cls,
|
||||
events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]],
|
||||
room_id: str,
|
||||
last_known_state_group: int,
|
||||
datastore: "StateGroupDataStore",
|
||||
) -> List[Tuple[EventBase, EventContext]]:
|
||||
"""
|
||||
Takes a list of events and their associated unpersisted contexts and persists
|
||||
the unpersisted contexts, returning a list of events and persisted contexts.
|
||||
Note that all the events must be in a linear chain (ie a <- b <- c).
|
||||
|
||||
Args:
|
||||
events_and_context: A list of events and their unpersisted contexts
|
||||
room_id: the room_id for the events
|
||||
last_known_state_group: the last persisted state group
|
||||
datastore: a state datastore
|
||||
"""
|
||||
amended_events_and_context = await datastore.store_state_deltas_for_batched(
|
||||
events_and_context, room_id, last_known_state_group
|
||||
)
|
||||
|
||||
events_and_persisted_context = []
|
||||
for event, unpersisted_context in amended_events_and_context:
|
||||
if event.is_state():
|
||||
context = EventContext(
|
||||
storage=unpersisted_context._storage,
|
||||
state_group=unpersisted_context.state_group_after_event,
|
||||
state_group_before_event=unpersisted_context.state_group_before_event,
|
||||
state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
|
||||
partial_state=unpersisted_context.partial_state,
|
||||
prev_group=unpersisted_context.state_group_before_event,
|
||||
delta_ids=unpersisted_context.state_delta_due_to_event,
|
||||
)
|
||||
else:
|
||||
context = EventContext(
|
||||
storage=unpersisted_context._storage,
|
||||
state_group=unpersisted_context.state_group_after_event,
|
||||
state_group_before_event=unpersisted_context.state_group_before_event,
|
||||
state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
|
||||
partial_state=unpersisted_context.partial_state,
|
||||
prev_group=unpersisted_context.prev_group_for_state_group_before_event,
|
||||
delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
|
||||
)
|
||||
events_and_persisted_context.append((event, context))
|
||||
return events_and_persisted_context
|
||||
|
||||
async def get_prev_state_ids(
|
||||
self, state_filter: Optional["StateFilter"] = None
|
||||
) -> StateMap[str]:
|
||||
|
|
|
@ -574,7 +574,7 @@ class EventCreationHandler:
|
|||
state_map: Optional[StateMap[str]] = None,
|
||||
for_batch: bool = False,
|
||||
current_state_group: Optional[int] = None,
|
||||
) -> Tuple[EventBase, EventContext]:
|
||||
) -> Tuple[EventBase, UnpersistedEventContextBase]:
|
||||
"""
|
||||
Given a dict from a client, create a new event. If bool for_batch is true, will
|
||||
create an event using the prev_event_ids, and will create an event context for
|
||||
|
@ -721,8 +721,6 @@ class EventCreationHandler:
|
|||
current_state_group=current_state_group,
|
||||
)
|
||||
|
||||
context = await unpersisted_context.persist(event)
|
||||
|
||||
# In an ideal world we wouldn't need the second part of this condition. However,
|
||||
# this behaviour isn't spec'd yet, meaning we should be able to deactivate this
|
||||
# behaviour. Another reason is that this code is also evaluated each time a new
|
||||
|
@ -739,7 +737,7 @@ class EventCreationHandler:
|
|||
assert state_map is not None
|
||||
prev_event_id = state_map.get((EventTypes.Member, event.sender))
|
||||
else:
|
||||
prev_state_ids = await context.get_prev_state_ids(
|
||||
prev_state_ids = await unpersisted_context.get_prev_state_ids(
|
||||
StateFilter.from_types([(EventTypes.Member, None)])
|
||||
)
|
||||
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
|
||||
|
@ -764,8 +762,7 @@ class EventCreationHandler:
|
|||
)
|
||||
|
||||
self.validator.validate_new(event, self.config)
|
||||
|
||||
return event, context
|
||||
return event, unpersisted_context
|
||||
|
||||
async def _is_exempt_from_privacy_policy(
|
||||
self, builder: EventBuilder, requester: Requester
|
||||
|
@ -1005,7 +1002,7 @@ class EventCreationHandler:
|
|||
max_retries = 5
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
event, context = await self.create_event(
|
||||
event, unpersisted_context = await self.create_event(
|
||||
requester,
|
||||
event_dict,
|
||||
txn_id=txn_id,
|
||||
|
@ -1016,6 +1013,7 @@ class EventCreationHandler:
|
|||
historical=historical,
|
||||
depth=depth,
|
||||
)
|
||||
context = await unpersisted_context.persist(event)
|
||||
|
||||
assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
|
||||
event.sender,
|
||||
|
@ -1190,7 +1188,6 @@ class EventCreationHandler:
|
|||
if for_batch:
|
||||
assert prev_event_ids is not None
|
||||
assert state_map is not None
|
||||
assert current_state_group is not None
|
||||
auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map)
|
||||
event = await builder.build(
|
||||
prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
|
||||
|
@ -2046,7 +2043,7 @@ class EventCreationHandler:
|
|||
max_retries = 5
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
event, context = await self.create_event(
|
||||
event, unpersisted_context = await self.create_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Dummy,
|
||||
|
@ -2055,6 +2052,7 @@ class EventCreationHandler:
|
|||
"sender": user_id,
|
||||
},
|
||||
)
|
||||
context = await unpersisted_context.persist(event)
|
||||
|
||||
event.internal_metadata.proactively_send = False
|
||||
|
||||
|
|
|
@ -51,6 +51,7 @@ from synapse.api.filtering import Filter
|
|||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.event_auth import validate_event_for_room_version
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import UnpersistedEventContext
|
||||
from synapse.events.utils import copy_and_fixup_power_levels_contents
|
||||
from synapse.handlers.relations import BundledAggregations
|
||||
from synapse.module_api import NOT_SPAM
|
||||
|
@ -211,7 +212,7 @@ class RoomCreationHandler:
|
|||
# the required power level to send the tombstone event.
|
||||
(
|
||||
tombstone_event,
|
||||
tombstone_context,
|
||||
tombstone_unpersisted_context,
|
||||
) = await self.event_creation_handler.create_event(
|
||||
requester,
|
||||
{
|
||||
|
@ -225,6 +226,9 @@ class RoomCreationHandler:
|
|||
},
|
||||
},
|
||||
)
|
||||
tombstone_context = await tombstone_unpersisted_context.persist(
|
||||
tombstone_event
|
||||
)
|
||||
validate_event_for_room_version(tombstone_event)
|
||||
await self._event_auth_handler.check_auth_rules_from_context(
|
||||
tombstone_event
|
||||
|
@ -1092,7 +1096,7 @@ class RoomCreationHandler:
|
|||
content: JsonDict,
|
||||
for_batch: bool,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[EventBase, synapse.events.snapshot.EventContext]:
|
||||
) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]:
|
||||
"""
|
||||
Creates an event and associated event context.
|
||||
Args:
|
||||
|
@ -1111,20 +1115,23 @@ class RoomCreationHandler:
|
|||
|
||||
event_dict = create_event_dict(etype, content, **kwargs)
|
||||
|
||||
new_event, new_context = await self.event_creation_handler.create_event(
|
||||
(
|
||||
new_event,
|
||||
new_unpersisted_context,
|
||||
) = await self.event_creation_handler.create_event(
|
||||
creator,
|
||||
event_dict,
|
||||
prev_event_ids=prev_event,
|
||||
depth=depth,
|
||||
state_map=state_map,
|
||||
for_batch=for_batch,
|
||||
current_state_group=current_state_group,
|
||||
)
|
||||
|
||||
depth += 1
|
||||
prev_event = [new_event.event_id]
|
||||
state_map[(new_event.type, new_event.state_key)] = new_event.event_id
|
||||
|
||||
return new_event, new_context
|
||||
return new_event, new_unpersisted_context
|
||||
|
||||
try:
|
||||
config = self._presets_dict[preset_config]
|
||||
|
@ -1134,10 +1141,10 @@ class RoomCreationHandler:
|
|||
)
|
||||
|
||||
creation_content.update({"creator": creator_id})
|
||||
creation_event, creation_context = await create_event(
|
||||
creation_event, unpersisted_creation_context = await create_event(
|
||||
EventTypes.Create, creation_content, False
|
||||
)
|
||||
|
||||
creation_context = await unpersisted_creation_context.persist(creation_event)
|
||||
logger.debug("Sending %s in new room", EventTypes.Member)
|
||||
ev = await self.event_creation_handler.handle_new_client_event(
|
||||
requester=creator,
|
||||
|
@ -1181,7 +1188,6 @@ class RoomCreationHandler:
|
|||
power_event, power_context = await create_event(
|
||||
EventTypes.PowerLevels, pl_content, True
|
||||
)
|
||||
current_state_group = power_context._state_group
|
||||
events_to_send.append((power_event, power_context))
|
||||
else:
|
||||
power_level_content: JsonDict = {
|
||||
|
@ -1230,14 +1236,12 @@ class RoomCreationHandler:
|
|||
power_level_content,
|
||||
True,
|
||||
)
|
||||
current_state_group = pl_context._state_group
|
||||
events_to_send.append((pl_event, pl_context))
|
||||
|
||||
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
|
||||
room_alias_event, room_alias_context = await create_event(
|
||||
EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
|
||||
)
|
||||
current_state_group = room_alias_context._state_group
|
||||
events_to_send.append((room_alias_event, room_alias_context))
|
||||
|
||||
if (EventTypes.JoinRules, "") not in initial_state:
|
||||
|
@ -1246,7 +1250,6 @@ class RoomCreationHandler:
|
|||
{"join_rule": config["join_rules"]},
|
||||
True,
|
||||
)
|
||||
current_state_group = join_rules_context._state_group
|
||||
events_to_send.append((join_rules_event, join_rules_context))
|
||||
|
||||
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
|
||||
|
@ -1255,7 +1258,6 @@ class RoomCreationHandler:
|
|||
{"history_visibility": config["history_visibility"]},
|
||||
True,
|
||||
)
|
||||
current_state_group = visibility_context._state_group
|
||||
events_to_send.append((visibility_event, visibility_context))
|
||||
|
||||
if config["guest_can_join"]:
|
||||
|
@ -1265,14 +1267,12 @@ class RoomCreationHandler:
|
|||
{EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
|
||||
True,
|
||||
)
|
||||
current_state_group = guest_access_context._state_group
|
||||
events_to_send.append((guest_access_event, guest_access_context))
|
||||
|
||||
for (etype, state_key), content in initial_state.items():
|
||||
event, context = await create_event(
|
||||
etype, content, True, state_key=state_key
|
||||
)
|
||||
current_state_group = context._state_group
|
||||
events_to_send.append((event, context))
|
||||
|
||||
if config["encrypted"]:
|
||||
|
@ -1284,9 +1284,16 @@ class RoomCreationHandler:
|
|||
)
|
||||
events_to_send.append((encryption_event, encryption_context))
|
||||
|
||||
datastore = self.hs.get_datastores().state
|
||||
events_and_context = (
|
||||
await UnpersistedEventContext.batch_persist_unpersisted_contexts(
|
||||
events_to_send, room_id, current_state_group, datastore
|
||||
)
|
||||
)
|
||||
|
||||
last_event = await self.event_creation_handler.handle_new_client_event(
|
||||
creator,
|
||||
events_to_send,
|
||||
events_and_context,
|
||||
ignore_shadow_ban=True,
|
||||
ratelimit=False,
|
||||
)
|
||||
|
|
|
@ -327,7 +327,7 @@ class RoomBatchHandler:
|
|||
# Mark all events as historical
|
||||
event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
|
||||
|
||||
event, context = await self.event_creation_handler.create_event(
|
||||
event, unpersisted_context = await self.event_creation_handler.create_event(
|
||||
await self.create_requester_for_user_id_from_app_service(
|
||||
ev["sender"], app_service_requester.app_service
|
||||
),
|
||||
|
@ -345,7 +345,7 @@ class RoomBatchHandler:
|
|||
historical=True,
|
||||
depth=inherited_depth,
|
||||
)
|
||||
|
||||
context = await unpersisted_context.persist(event)
|
||||
assert context._state_group
|
||||
|
||||
# Normally this is done when persisting the event but we have to
|
||||
|
|
|
@ -414,7 +414,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
max_retries = 5
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
event, context = await self.event_creation_handler.create_event(
|
||||
(
|
||||
event,
|
||||
unpersisted_context,
|
||||
) = await self.event_creation_handler.create_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Member,
|
||||
|
@ -435,7 +438,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
outlier=outlier,
|
||||
historical=historical,
|
||||
)
|
||||
|
||||
context = await unpersisted_context.persist(event)
|
||||
prev_state_ids = await context.get_prev_state_ids(
|
||||
StateFilter.from_types([(EventTypes.Member, None)])
|
||||
)
|
||||
|
@ -1944,7 +1947,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||
max_retries = 5
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
event, context = await self.event_creation_handler.create_event(
|
||||
(
|
||||
event,
|
||||
unpersisted_context,
|
||||
) = await self.event_creation_handler.create_event(
|
||||
requester,
|
||||
event_dict,
|
||||
txn_id=txn_id,
|
||||
|
@ -1952,6 +1958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||
auth_event_ids=auth_event_ids,
|
||||
outlier=True,
|
||||
)
|
||||
context = await unpersisted_context.persist(event)
|
||||
event.internal_metadata.out_of_band_membership = True
|
||||
|
||||
result_event = (
|
||||
|
|
|
@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se
|
|||
import attr
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
|
@ -401,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
fetched_keys=non_member_types,
|
||||
)
|
||||
|
||||
async def store_state_deltas_for_batched(
|
||||
self,
|
||||
events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]],
|
||||
room_id: str,
|
||||
prev_group: int,
|
||||
) -> List[Tuple[EventBase, UnpersistedEventContext]]:
|
||||
"""Generate and store state deltas for a group of events and contexts created to be
|
||||
batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c).
|
||||
|
||||
Args:
|
||||
events_and_context: the events to generate and store a state groups for
|
||||
and their associated contexts
|
||||
room_id: the id of the room the events were created for
|
||||
prev_group: the state group of the last event persisted before the batched events
|
||||
were created
|
||||
"""
|
||||
|
||||
def insert_deltas_group_txn(
|
||||
txn: LoggingTransaction,
|
||||
events_and_context: List[Tuple[EventBase, UnpersistedEventContext]],
|
||||
prev_group: int,
|
||||
) -> List[Tuple[EventBase, UnpersistedEventContext]]:
|
||||
"""Generate and store state groups for the provided events and contexts.
|
||||
|
||||
Requires that we have the state as a delta from the last persisted state group.
|
||||
|
||||
Returns:
|
||||
A list of state groups
|
||||
"""
|
||||
is_in_db = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="state_groups",
|
||||
keyvalues={"id": prev_group},
|
||||
retcol="id",
|
||||
allow_none=True,
|
||||
)
|
||||
if not is_in_db:
|
||||
raise Exception(
|
||||
"Trying to persist state with unpersisted prev_group: %r"
|
||||
% (prev_group,)
|
||||
)
|
||||
|
||||
num_state_groups = sum(
|
||||
1 for event, _ in events_and_context if event.is_state()
|
||||
)
|
||||
|
||||
state_groups = self._state_group_seq_gen.get_next_mult_txn(
|
||||
txn, num_state_groups
|
||||
)
|
||||
|
||||
sg_before = prev_group
|
||||
state_group_iter = iter(state_groups)
|
||||
for event, context in events_and_context:
|
||||
if not event.is_state():
|
||||
context.state_group_after_event = sg_before
|
||||
context.state_group_before_event = sg_before
|
||||
continue
|
||||
|
||||
sg_after = next(state_group_iter)
|
||||
context.state_group_after_event = sg_after
|
||||
context.state_group_before_event = sg_before
|
||||
context.state_delta_due_to_event = {
|
||||
(event.type, event.state_key): event.event_id
|
||||
}
|
||||
sg_before = sg_after
|
||||
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="state_groups",
|
||||
keys=("id", "room_id", "event_id"),
|
||||
values=[
|
||||
(context.state_group_after_event, room_id, event.event_id)
|
||||
for event, context in events_and_context
|
||||
if event.is_state()
|
||||
],
|
||||
)
|
||||
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="state_group_edges",
|
||||
keys=("state_group", "prev_state_group"),
|
||||
values=[
|
||||
(
|
||||
context.state_group_after_event,
|
||||
context.state_group_before_event,
|
||||
)
|
||||
for event, context in events_and_context
|
||||
if event.is_state()
|
||||
],
|
||||
)
|
||||
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="state_groups_state",
|
||||
keys=("state_group", "room_id", "type", "state_key", "event_id"),
|
||||
values=[
|
||||
(
|
||||
context.state_group_after_event,
|
||||
room_id,
|
||||
key[0],
|
||||
key[1],
|
||||
state_id,
|
||||
)
|
||||
for event, context in events_and_context
|
||||
if context.state_delta_due_to_event is not None
|
||||
for key, state_id in context.state_delta_due_to_event.items()
|
||||
],
|
||||
)
|
||||
return events_and_context
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"store_state_deltas_for_batched.insert_deltas_group",
|
||||
insert_deltas_group_txn,
|
||||
events_and_context,
|
||||
prev_group,
|
||||
)
|
||||
|
||||
async def store_state_group(
|
||||
self,
|
||||
event_id: str,
|
||||
|
|
|
@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor
|
|||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, room
|
||||
from synapse.server import HomeServer
|
||||
|
@ -79,7 +79,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
return memberEvent, memberEventContext
|
||||
|
||||
def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
|
||||
def _create_duplicate_event(
|
||||
self, txn_id: str
|
||||
) -> Tuple[EventBase, UnpersistedEventContextBase]:
|
||||
"""Create a new event with the given transaction ID. All events produced
|
||||
by this method will be considered duplicates.
|
||||
"""
|
||||
|
@ -107,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
txn_id = "something_suitably_random"
|
||||
|
||||
event1, context = self._create_duplicate_event(txn_id)
|
||||
event1, unpersisted_context = self._create_duplicate_event(txn_id)
|
||||
context = self.get_success(unpersisted_context.persist(event1))
|
||||
|
||||
ret_event1 = self.get_success(
|
||||
self.handler.handle_new_client_event(
|
||||
|
@ -119,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEqual(event1.event_id, ret_event1.event_id)
|
||||
|
||||
event2, context = self._create_duplicate_event(txn_id)
|
||||
event2, unpersisted_context = self._create_duplicate_event(txn_id)
|
||||
context = self.get_success(unpersisted_context.persist(event2))
|
||||
|
||||
# We want to test that the deduplication at the persit event end works,
|
||||
# so we want to make sure we test with different events.
|
||||
|
@ -140,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
# Let's test that calling `persist_event` directly also does the right
|
||||
# thing.
|
||||
event3, context = self._create_duplicate_event(txn_id)
|
||||
event3, unpersisted_context = self._create_duplicate_event(txn_id)
|
||||
context = self.get_success(unpersisted_context.persist(event3))
|
||||
|
||||
self.assertNotEqual(event1.event_id, event3.event_id)
|
||||
|
||||
ret_event3, event_pos3, _ = self.get_success(
|
||||
|
@ -154,7 +160,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
# Let's test that calling `persist_events` directly also does the right
|
||||
# thing.
|
||||
event4, context = self._create_duplicate_event(txn_id)
|
||||
event4, unpersisted_context = self._create_duplicate_event(txn_id)
|
||||
context = self.get_success(unpersisted_context.persist(event4))
|
||||
self.assertNotEqual(event1.event_id, event3.event_id)
|
||||
|
||||
events, _ = self.get_success(
|
||||
|
@ -174,8 +181,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
|||
txn_id = "something_else_suitably_random"
|
||||
|
||||
# Create two duplicate events to persist at the same time
|
||||
event1, context1 = self._create_duplicate_event(txn_id)
|
||||
event2, context2 = self._create_duplicate_event(txn_id)
|
||||
event1, unpersisted_context1 = self._create_duplicate_event(txn_id)
|
||||
context1 = self.get_success(unpersisted_context1.persist(event1))
|
||||
event2, unpersisted_context2 = self._create_duplicate_event(txn_id)
|
||||
context2 = self.get_success(unpersisted_context2.persist(event2))
|
||||
|
||||
# Ensure their event IDs are different to start with
|
||||
self.assertNotEqual(event1.event_id, event2.event_id)
|
||||
|
|
|
@ -507,7 +507,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
# Lower the permissions of the inviter.
|
||||
event_creation_handler = self.hs.get_event_creation_handler()
|
||||
requester = create_requester(inviter)
|
||||
event, context = self.get_success(
|
||||
event, unpersisted_context = self.get_success(
|
||||
event_creation_handler.create_event(
|
||||
requester,
|
||||
{
|
||||
|
@ -519,6 +519,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
)
|
||||
context = self.get_success(unpersisted_context.persist(event))
|
||||
self.get_success(
|
||||
event_creation_handler.handle_new_client_event(
|
||||
requester, events_and_context=[(event, context)]
|
||||
|
|
|
@ -130,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
|
|||
|
||||
# Create a new message event, and try to evaluate it under the dodgy
|
||||
# power level event.
|
||||
event, context = self.get_success(
|
||||
event, unpersisted_context = self.get_success(
|
||||
self.event_creation_handler.create_event(
|
||||
self.requester,
|
||||
{
|
||||
|
@ -145,6 +145,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
|
|||
prev_event_ids=[pl_event_id],
|
||||
)
|
||||
)
|
||||
context = self.get_success(unpersisted_context.persist(event))
|
||||
|
||||
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
|
||||
# should not raise
|
||||
|
@ -170,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
|
|||
"""Ensure that push rules are not calculated when disabled in the config"""
|
||||
|
||||
# Create a new message event which should cause a notification.
|
||||
event, context = self.get_success(
|
||||
event, unpersisted_context = self.get_success(
|
||||
self.event_creation_handler.create_event(
|
||||
self.requester,
|
||||
{
|
||||
|
@ -184,6 +185,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
)
|
||||
context = self.get_success(unpersisted_context.persist(event))
|
||||
|
||||
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
|
||||
# Mock the method which calculates push rules -- we do this instead of
|
||||
|
@ -200,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
|
|||
) -> bool:
|
||||
"""Returns true iff the `mentions` trigger an event push action."""
|
||||
# Create a new message event which should cause a notification.
|
||||
event, context = self.get_success(
|
||||
event, unpersisted_context = self.get_success(
|
||||
self.event_creation_handler.create_event(
|
||||
self.requester,
|
||||
{
|
||||
|
@ -211,7 +213,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
)
|
||||
|
||||
context = self.get_success(unpersisted_context.persist(event))
|
||||
# Execute the push rule machinery.
|
||||
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
|
||||
|
||||
|
@ -390,7 +392,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
|
|||
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
|
||||
|
||||
# Create & persist an event to use as the parent of the relation.
|
||||
event, context = self.get_success(
|
||||
event, unpersisted_context = self.get_success(
|
||||
self.event_creation_handler.create_event(
|
||||
self.requester,
|
||||
{
|
||||
|
@ -404,6 +406,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
|
|||
},
|
||||
)
|
||||
)
|
||||
context = self.get_success(unpersisted_context.persist(event))
|
||||
self.get_success(
|
||||
self.event_creation_handler.handle_new_client_event(
|
||||
self.requester, events_and_context=[(event, context)]
|
||||
|
|
|
@ -713,7 +713,7 @@ class RoomsCreateTestCase(RoomBase):
|
|||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
self.assertTrue("room_id" in channel.json_body)
|
||||
assert channel.resource_usage is not None
|
||||
self.assertEqual(33, channel.resource_usage.db_txn_count)
|
||||
self.assertEqual(30, channel.resource_usage.db_txn_count)
|
||||
|
||||
def test_post_room_initial_state(self) -> None:
|
||||
# POST with initial_state config key, expect new room id
|
||||
|
@ -726,7 +726,7 @@ class RoomsCreateTestCase(RoomBase):
|
|||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
self.assertTrue("room_id" in channel.json_body)
|
||||
assert channel.resource_usage is not None
|
||||
self.assertEqual(36, channel.resource_usage.db_txn_count)
|
||||
self.assertEqual(32, channel.resource_usage.db_txn_count)
|
||||
|
||||
def test_post_room_visibility_key(self) -> None:
|
||||
# POST with visibility config key, expect new room id
|
||||
|
|
|
@ -522,7 +522,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
|||
latest_event_ids = self.get_success(
|
||||
self.store.get_prev_events_for_room(room_id)
|
||||
)
|
||||
event, context = self.get_success(
|
||||
event, unpersisted_context = self.get_success(
|
||||
event_handler.create_event(
|
||||
self.requester,
|
||||
{
|
||||
|
@ -535,6 +535,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
|||
prev_event_ids=latest_event_ids,
|
||||
)
|
||||
)
|
||||
context = self.get_success(unpersisted_context.persist(event))
|
||||
self.get_success(
|
||||
event_handler.handle_new_client_event(
|
||||
self.requester, events_and_context=[(event, context)]
|
||||
|
@ -544,7 +545,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
|||
assert state_ids1 is not None
|
||||
state1 = set(state_ids1.values())
|
||||
|
||||
event, context = self.get_success(
|
||||
event, unpersisted_context = self.get_success(
|
||||
event_handler.create_event(
|
||||
self.requester,
|
||||
{
|
||||
|
@ -557,6 +558,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
|||
prev_event_ids=latest_event_ids,
|
||||
)
|
||||
)
|
||||
context = self.get_success(unpersisted_context.persist(event))
|
||||
self.get_success(
|
||||
event_handler.handle_new_client_event(
|
||||
self.requester, events_and_context=[(event, context)]
|
||||
|
|
|
@ -496,3 +496,129 @@ class StateStoreTestCase(HomeserverTestCase):
|
|||
|
||||
self.assertEqual(is_all, True)
|
||||
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
|
||||
|
||||
def test_batched_state_group_storing(self) -> None:
|
||||
creation_event = self.inject_state_event(
|
||||
self.room, self.u_alice, EventTypes.Create, "", {}
|
||||
)
|
||||
state_to_event = self.get_success(
|
||||
self.storage.state.get_state_groups(
|
||||
self.room.to_string(), [creation_event.event_id]
|
||||
)
|
||||
)
|
||||
current_state_group = list(state_to_event.keys())[0]
|
||||
|
||||
# create some unpersisted events and event contexts to store against room
|
||||
events_and_context = []
|
||||
builder = self.event_builder_factory.for_room_version(
|
||||
RoomVersions.V1,
|
||||
{
|
||||
"type": EventTypes.Name,
|
||||
"sender": self.u_alice.to_string(),
|
||||
"state_key": "",
|
||||
"room_id": self.room.to_string(),
|
||||
"content": {"name": "first rename of room"},
|
||||
},
|
||||
)
|
||||
|
||||
event1, unpersisted_context1 = self.get_success(
|
||||
self.event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
events_and_context.append((event1, unpersisted_context1))
|
||||
|
||||
builder2 = self.event_builder_factory.for_room_version(
|
||||
RoomVersions.V1,
|
||||
{
|
||||
"type": EventTypes.JoinRules,
|
||||
"sender": self.u_alice.to_string(),
|
||||
"state_key": "",
|
||||
"room_id": self.room.to_string(),
|
||||
"content": {"join_rule": "private"},
|
||||
},
|
||||
)
|
||||
|
||||
event2, unpersisted_context2 = self.get_success(
|
||||
self.event_creation_handler.create_new_client_event(builder2)
|
||||
)
|
||||
events_and_context.append((event2, unpersisted_context2))
|
||||
|
||||
builder3 = self.event_builder_factory.for_room_version(
|
||||
RoomVersions.V1,
|
||||
{
|
||||
"type": EventTypes.Message,
|
||||
"sender": self.u_alice.to_string(),
|
||||
"room_id": self.room.to_string(),
|
||||
"content": {"body": "hello from event 3", "msgtype": "m.text"},
|
||||
},
|
||||
)
|
||||
|
||||
event3, unpersisted_context3 = self.get_success(
|
||||
self.event_creation_handler.create_new_client_event(builder3)
|
||||
)
|
||||
events_and_context.append((event3, unpersisted_context3))
|
||||
|
||||
builder4 = self.event_builder_factory.for_room_version(
|
||||
RoomVersions.V1,
|
||||
{
|
||||
"type": EventTypes.JoinRules,
|
||||
"sender": self.u_alice.to_string(),
|
||||
"state_key": "",
|
||||
"room_id": self.room.to_string(),
|
||||
"content": {"join_rule": "public"},
|
||||
},
|
||||
)
|
||||
|
||||
event4, unpersisted_context4 = self.get_success(
|
||||
self.event_creation_handler.create_new_client_event(builder4)
|
||||
)
|
||||
events_and_context.append((event4, unpersisted_context4))
|
||||
|
||||
processed_events_and_context = self.get_success(
|
||||
self.hs.get_datastores().state.store_state_deltas_for_batched(
|
||||
events_and_context, self.room.to_string(), current_state_group
|
||||
)
|
||||
)
|
||||
|
||||
# check that only state events are in state_groups, and all state events are in state_groups
|
||||
res = self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="state_groups",
|
||||
keyvalues=None,
|
||||
retcols=("event_id",),
|
||||
)
|
||||
)
|
||||
|
||||
events = []
|
||||
for result in res:
|
||||
self.assertNotIn(event3.event_id, result)
|
||||
events.append(result.get("event_id"))
|
||||
|
||||
for event, _ in processed_events_and_context:
|
||||
if event.is_state():
|
||||
self.assertIn(event.event_id, events)
|
||||
|
||||
# check that each unique state has state group in state_groups_state and that the
|
||||
# type/state key is correct, and check that each state event's state group
|
||||
# has an entry and prev event in state_group_edges
|
||||
for event, context in processed_events_and_context:
|
||||
if event.is_state():
|
||||
state = self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="state_groups_state",
|
||||
keyvalues={"state_group": context.state_group_after_event},
|
||||
retcols=("type", "state_key"),
|
||||
)
|
||||
)
|
||||
self.assertEqual(event.type, state[0].get("type"))
|
||||
self.assertEqual(event.state_key, state[0].get("state_key"))
|
||||
|
||||
groups = self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
table="state_group_edges",
|
||||
keyvalues={"state_group": str(context.state_group_after_event)},
|
||||
retcols=("*",),
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
context.state_group_before_event, groups[0].get("prev_state_group")
|
||||
)
|
||||
|
|
|
@ -723,7 +723,7 @@ class HomeserverTestCase(TestCase):
|
|||
event_creator = self.hs.get_event_creation_handler()
|
||||
requester = create_requester(user)
|
||||
|
||||
event, context = self.get_success(
|
||||
event, unpersisted_context = self.get_success(
|
||||
event_creator.create_event(
|
||||
requester,
|
||||
{
|
||||
|
@ -735,7 +735,7 @@ class HomeserverTestCase(TestCase):
|
|||
prev_event_ids=prev_event_ids,
|
||||
)
|
||||
)
|
||||
|
||||
context = self.get_success(unpersisted_context.persist(event))
|
||||
if soft_failed:
|
||||
event.internal_metadata.soft_failed = True
|
||||
|
||||
|
|
Loading…
Reference in a new issue