Batch up storing state groups when creating new room (#14918)

This commit is contained in:
Shay 2023-02-24 13:15:29 -08:00 committed by GitHub
parent 335f52d595
commit 1c95ddd09b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 371 additions and 49 deletions

1
changelog.d/14918.misc Normal file
View file

@ -0,0 +1 @@
Batch up storing state groups when creating a new room.

View file

@ -23,6 +23,7 @@ from synapse.types import JsonDict, StateMap
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.storage.controllers import StorageControllers from synapse.storage.controllers import StorageControllers
from synapse.storage.databases import StateGroupDataStore
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.types.state import StateFilter from synapse.types.state import StateFilter
@ -348,6 +349,54 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
partial_state: bool partial_state: bool
state_map_before_event: Optional[StateMap[str]] = None state_map_before_event: Optional[StateMap[str]] = None
@classmethod
async def batch_persist_unpersisted_contexts(
cls,
events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]],
room_id: str,
last_known_state_group: int,
datastore: "StateGroupDataStore",
) -> List[Tuple[EventBase, EventContext]]:
"""
Takes a list of events and their associated unpersisted contexts and persists
the unpersisted contexts, returning a list of events and persisted contexts.
Note that all the events must be in a linear chain (ie a <- b <- c).
Args:
events_and_context: A list of events and their unpersisted contexts
room_id: the room_id for the events
last_known_state_group: the last persisted state group
datastore: a state datastore
"""
amended_events_and_context = await datastore.store_state_deltas_for_batched(
events_and_context, room_id, last_known_state_group
)
events_and_persisted_context = []
for event, unpersisted_context in amended_events_and_context:
if event.is_state():
context = EventContext(
storage=unpersisted_context._storage,
state_group=unpersisted_context.state_group_after_event,
state_group_before_event=unpersisted_context.state_group_before_event,
state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
partial_state=unpersisted_context.partial_state,
prev_group=unpersisted_context.state_group_before_event,
delta_ids=unpersisted_context.state_delta_due_to_event,
)
else:
context = EventContext(
storage=unpersisted_context._storage,
state_group=unpersisted_context.state_group_after_event,
state_group_before_event=unpersisted_context.state_group_before_event,
state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
partial_state=unpersisted_context.partial_state,
prev_group=unpersisted_context.prev_group_for_state_group_before_event,
delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
)
events_and_persisted_context.append((event, context))
return events_and_persisted_context
async def get_prev_state_ids( async def get_prev_state_ids(
self, state_filter: Optional["StateFilter"] = None self, state_filter: Optional["StateFilter"] = None
) -> StateMap[str]: ) -> StateMap[str]:

View file

@ -574,7 +574,7 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None, state_map: Optional[StateMap[str]] = None,
for_batch: bool = False, for_batch: bool = False,
current_state_group: Optional[int] = None, current_state_group: Optional[int] = None,
) -> Tuple[EventBase, EventContext]: ) -> Tuple[EventBase, UnpersistedEventContextBase]:
""" """
Given a dict from a client, create a new event. If bool for_batch is true, will Given a dict from a client, create a new event. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for create an event using the prev_event_ids, and will create an event context for
@ -721,8 +721,6 @@ class EventCreationHandler:
current_state_group=current_state_group, current_state_group=current_state_group,
) )
context = await unpersisted_context.persist(event)
# In an ideal world we wouldn't need the second part of this condition. However, # In an ideal world we wouldn't need the second part of this condition. However,
# this behaviour isn't spec'd yet, meaning we should be able to deactivate this # this behaviour isn't spec'd yet, meaning we should be able to deactivate this
# behaviour. Another reason is that this code is also evaluated each time a new # behaviour. Another reason is that this code is also evaluated each time a new
@ -739,7 +737,7 @@ class EventCreationHandler:
assert state_map is not None assert state_map is not None
prev_event_id = state_map.get((EventTypes.Member, event.sender)) prev_event_id = state_map.get((EventTypes.Member, event.sender))
else: else:
prev_state_ids = await context.get_prev_state_ids( prev_state_ids = await unpersisted_context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)]) StateFilter.from_types([(EventTypes.Member, None)])
) )
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
@ -764,8 +762,7 @@ class EventCreationHandler:
) )
self.validator.validate_new(event, self.config) self.validator.validate_new(event, self.config)
return event, unpersisted_context
return event, context
async def _is_exempt_from_privacy_policy( async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester self, builder: EventBuilder, requester: Requester
@ -1005,7 +1002,7 @@ class EventCreationHandler:
max_retries = 5 max_retries = 5
for i in range(max_retries): for i in range(max_retries):
try: try:
event, context = await self.create_event( event, unpersisted_context = await self.create_event(
requester, requester,
event_dict, event_dict,
txn_id=txn_id, txn_id=txn_id,
@ -1016,6 +1013,7 @@ class EventCreationHandler:
historical=historical, historical=historical,
depth=depth, depth=depth,
) )
context = await unpersisted_context.persist(event)
assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
event.sender, event.sender,
@ -1190,7 +1188,6 @@ class EventCreationHandler:
if for_batch: if for_batch:
assert prev_event_ids is not None assert prev_event_ids is not None
assert state_map is not None assert state_map is not None
assert current_state_group is not None
auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map) auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map)
event = await builder.build( event = await builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
@ -2046,7 +2043,7 @@ class EventCreationHandler:
max_retries = 5 max_retries = 5
for i in range(max_retries): for i in range(max_retries):
try: try:
event, context = await self.create_event( event, unpersisted_context = await self.create_event(
requester, requester,
{ {
"type": EventTypes.Dummy, "type": EventTypes.Dummy,
@ -2055,6 +2052,7 @@ class EventCreationHandler:
"sender": user_id, "sender": user_id,
}, },
) )
context = await unpersisted_context.persist(event)
event.internal_metadata.proactively_send = False event.internal_metadata.proactively_send = False

View file

@ -51,6 +51,7 @@ from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.event_auth import validate_event_for_room_version from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import UnpersistedEventContext
from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.events.utils import copy_and_fixup_power_levels_contents
from synapse.handlers.relations import BundledAggregations from synapse.handlers.relations import BundledAggregations
from synapse.module_api import NOT_SPAM from synapse.module_api import NOT_SPAM
@ -211,7 +212,7 @@ class RoomCreationHandler:
# the required power level to send the tombstone event. # the required power level to send the tombstone event.
( (
tombstone_event, tombstone_event,
tombstone_context, tombstone_unpersisted_context,
) = await self.event_creation_handler.create_event( ) = await self.event_creation_handler.create_event(
requester, requester,
{ {
@ -225,6 +226,9 @@ class RoomCreationHandler:
}, },
}, },
) )
tombstone_context = await tombstone_unpersisted_context.persist(
tombstone_event
)
validate_event_for_room_version(tombstone_event) validate_event_for_room_version(tombstone_event)
await self._event_auth_handler.check_auth_rules_from_context( await self._event_auth_handler.check_auth_rules_from_context(
tombstone_event tombstone_event
@ -1092,7 +1096,7 @@ class RoomCreationHandler:
content: JsonDict, content: JsonDict,
for_batch: bool, for_batch: bool,
**kwargs: Any, **kwargs: Any,
) -> Tuple[EventBase, synapse.events.snapshot.EventContext]: ) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]:
""" """
Creates an event and associated event context. Creates an event and associated event context.
Args: Args:
@ -1111,20 +1115,23 @@ class RoomCreationHandler:
event_dict = create_event_dict(etype, content, **kwargs) event_dict = create_event_dict(etype, content, **kwargs)
new_event, new_context = await self.event_creation_handler.create_event( (
new_event,
new_unpersisted_context,
) = await self.event_creation_handler.create_event(
creator, creator,
event_dict, event_dict,
prev_event_ids=prev_event, prev_event_ids=prev_event,
depth=depth, depth=depth,
state_map=state_map, state_map=state_map,
for_batch=for_batch, for_batch=for_batch,
current_state_group=current_state_group,
) )
depth += 1 depth += 1
prev_event = [new_event.event_id] prev_event = [new_event.event_id]
state_map[(new_event.type, new_event.state_key)] = new_event.event_id state_map[(new_event.type, new_event.state_key)] = new_event.event_id
return new_event, new_context return new_event, new_unpersisted_context
try: try:
config = self._presets_dict[preset_config] config = self._presets_dict[preset_config]
@ -1134,10 +1141,10 @@ class RoomCreationHandler:
) )
creation_content.update({"creator": creator_id}) creation_content.update({"creator": creator_id})
creation_event, creation_context = await create_event( creation_event, unpersisted_creation_context = await create_event(
EventTypes.Create, creation_content, False EventTypes.Create, creation_content, False
) )
creation_context = await unpersisted_creation_context.persist(creation_event)
logger.debug("Sending %s in new room", EventTypes.Member) logger.debug("Sending %s in new room", EventTypes.Member)
ev = await self.event_creation_handler.handle_new_client_event( ev = await self.event_creation_handler.handle_new_client_event(
requester=creator, requester=creator,
@ -1181,7 +1188,6 @@ class RoomCreationHandler:
power_event, power_context = await create_event( power_event, power_context = await create_event(
EventTypes.PowerLevels, pl_content, True EventTypes.PowerLevels, pl_content, True
) )
current_state_group = power_context._state_group
events_to_send.append((power_event, power_context)) events_to_send.append((power_event, power_context))
else: else:
power_level_content: JsonDict = { power_level_content: JsonDict = {
@ -1230,14 +1236,12 @@ class RoomCreationHandler:
power_level_content, power_level_content,
True, True,
) )
current_state_group = pl_context._state_group
events_to_send.append((pl_event, pl_context)) events_to_send.append((pl_event, pl_context))
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
room_alias_event, room_alias_context = await create_event( room_alias_event, room_alias_context = await create_event(
EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
) )
current_state_group = room_alias_context._state_group
events_to_send.append((room_alias_event, room_alias_context)) events_to_send.append((room_alias_event, room_alias_context))
if (EventTypes.JoinRules, "") not in initial_state: if (EventTypes.JoinRules, "") not in initial_state:
@ -1246,7 +1250,6 @@ class RoomCreationHandler:
{"join_rule": config["join_rules"]}, {"join_rule": config["join_rules"]},
True, True,
) )
current_state_group = join_rules_context._state_group
events_to_send.append((join_rules_event, join_rules_context)) events_to_send.append((join_rules_event, join_rules_context))
if (EventTypes.RoomHistoryVisibility, "") not in initial_state: if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
@ -1255,7 +1258,6 @@ class RoomCreationHandler:
{"history_visibility": config["history_visibility"]}, {"history_visibility": config["history_visibility"]},
True, True,
) )
current_state_group = visibility_context._state_group
events_to_send.append((visibility_event, visibility_context)) events_to_send.append((visibility_event, visibility_context))
if config["guest_can_join"]: if config["guest_can_join"]:
@ -1265,14 +1267,12 @@ class RoomCreationHandler:
{EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
True, True,
) )
current_state_group = guest_access_context._state_group
events_to_send.append((guest_access_event, guest_access_context)) events_to_send.append((guest_access_event, guest_access_context))
for (etype, state_key), content in initial_state.items(): for (etype, state_key), content in initial_state.items():
event, context = await create_event( event, context = await create_event(
etype, content, True, state_key=state_key etype, content, True, state_key=state_key
) )
current_state_group = context._state_group
events_to_send.append((event, context)) events_to_send.append((event, context))
if config["encrypted"]: if config["encrypted"]:
@ -1284,9 +1284,16 @@ class RoomCreationHandler:
) )
events_to_send.append((encryption_event, encryption_context)) events_to_send.append((encryption_event, encryption_context))
datastore = self.hs.get_datastores().state
events_and_context = (
await UnpersistedEventContext.batch_persist_unpersisted_contexts(
events_to_send, room_id, current_state_group, datastore
)
)
last_event = await self.event_creation_handler.handle_new_client_event( last_event = await self.event_creation_handler.handle_new_client_event(
creator, creator,
events_to_send, events_and_context,
ignore_shadow_ban=True, ignore_shadow_ban=True,
ratelimit=False, ratelimit=False,
) )

View file

@ -327,7 +327,7 @@ class RoomBatchHandler:
# Mark all events as historical # Mark all events as historical
event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
event, context = await self.event_creation_handler.create_event( event, unpersisted_context = await self.event_creation_handler.create_event(
await self.create_requester_for_user_id_from_app_service( await self.create_requester_for_user_id_from_app_service(
ev["sender"], app_service_requester.app_service ev["sender"], app_service_requester.app_service
), ),
@ -345,7 +345,7 @@ class RoomBatchHandler:
historical=True, historical=True,
depth=inherited_depth, depth=inherited_depth,
) )
context = await unpersisted_context.persist(event)
assert context._state_group assert context._state_group
# Normally this is done when persisting the event but we have to # Normally this is done when persisting the event but we have to

View file

@ -414,7 +414,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
max_retries = 5 max_retries = 5
for i in range(max_retries): for i in range(max_retries):
try: try:
event, context = await self.event_creation_handler.create_event( (
event,
unpersisted_context,
) = await self.event_creation_handler.create_event(
requester, requester,
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
@ -435,7 +438,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
outlier=outlier, outlier=outlier,
historical=historical, historical=historical,
) )
context = await unpersisted_context.persist(event)
prev_state_ids = await context.get_prev_state_ids( prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)]) StateFilter.from_types([(EventTypes.Member, None)])
) )
@ -1944,7 +1947,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
max_retries = 5 max_retries = 5
for i in range(max_retries): for i in range(max_retries):
try: try:
event, context = await self.event_creation_handler.create_event( (
event,
unpersisted_context,
) = await self.event_creation_handler.create_event(
requester, requester,
event_dict, event_dict,
txn_id=txn_id, txn_id=txn_id,
@ -1952,6 +1958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
auth_event_ids=auth_event_ids, auth_event_ids=auth_event_ids,
outlier=True, outlier=True,
) )
context = await unpersisted_context.persist(event)
event.internal_metadata.out_of_band_membership = True event.internal_metadata.out_of_band_membership = True
result_event = ( result_event = (

View file

@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se
import attr import attr
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
@ -401,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
fetched_keys=non_member_types, fetched_keys=non_member_types,
) )
async def store_state_deltas_for_batched(
self,
events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]],
room_id: str,
prev_group: int,
) -> List[Tuple[EventBase, UnpersistedEventContext]]:
"""Generate and store state deltas for a group of events and contexts created to be
batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c).
Args:
events_and_context: the events to generate and store a state groups for
and their associated contexts
room_id: the id of the room the events were created for
prev_group: the state group of the last event persisted before the batched events
were created
"""
def insert_deltas_group_txn(
txn: LoggingTransaction,
events_and_context: List[Tuple[EventBase, UnpersistedEventContext]],
prev_group: int,
) -> List[Tuple[EventBase, UnpersistedEventContext]]:
"""Generate and store state groups for the provided events and contexts.
Requires that we have the state as a delta from the last persisted state group.
Returns:
A list of state groups
"""
is_in_db = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (prev_group,)
)
num_state_groups = sum(
1 for event, _ in events_and_context if event.is_state()
)
state_groups = self._state_group_seq_gen.get_next_mult_txn(
txn, num_state_groups
)
sg_before = prev_group
state_group_iter = iter(state_groups)
for event, context in events_and_context:
if not event.is_state():
context.state_group_after_event = sg_before
context.state_group_before_event = sg_before
continue
sg_after = next(state_group_iter)
context.state_group_after_event = sg_after
context.state_group_before_event = sg_before
context.state_delta_due_to_event = {
(event.type, event.state_key): event.event_id
}
sg_before = sg_after
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups",
keys=("id", "room_id", "event_id"),
values=[
(context.state_group_after_event, room_id, event.event_id)
for event, context in events_and_context
if event.is_state()
],
)
self.db_pool.simple_insert_many_txn(
txn,
table="state_group_edges",
keys=("state_group", "prev_state_group"),
values=[
(
context.state_group_after_event,
context.state_group_before_event,
)
for event, context in events_and_context
if event.is_state()
],
)
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
keys=("state_group", "room_id", "type", "state_key", "event_id"),
values=[
(
context.state_group_after_event,
room_id,
key[0],
key[1],
state_id,
)
for event, context in events_and_context
if context.state_delta_due_to_event is not None
for key, state_id in context.state_delta_due_to_event.items()
],
)
return events_and_context
return await self.db_pool.runInteraction(
"store_state_deltas_for_batched.insert_deltas_group",
insert_deltas_group_txn,
events_and_context,
prev_group,
)
async def store_state_group( async def store_state_group(
self, self,
event_id: str, event_id: str,

View file

@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer from synapse.server import HomeServer
@ -79,7 +79,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
return memberEvent, memberEventContext return memberEvent, memberEventContext
def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]: def _create_duplicate_event(
self, txn_id: str
) -> Tuple[EventBase, UnpersistedEventContextBase]:
"""Create a new event with the given transaction ID. All events produced """Create a new event with the given transaction ID. All events produced
by this method will be considered duplicates. by this method will be considered duplicates.
""" """
@ -107,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_suitably_random" txn_id = "something_suitably_random"
event1, context = self._create_duplicate_event(txn_id) event1, unpersisted_context = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event1))
ret_event1 = self.get_success( ret_event1 = self.get_success(
self.handler.handle_new_client_event( self.handler.handle_new_client_event(
@ -119,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertEqual(event1.event_id, ret_event1.event_id) self.assertEqual(event1.event_id, ret_event1.event_id)
event2, context = self._create_duplicate_event(txn_id) event2, unpersisted_context = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event2))
# We want to test that the deduplication at the persit event end works, # We want to test that the deduplication at the persit event end works,
# so we want to make sure we test with different events. # so we want to make sure we test with different events.
@ -140,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_event` directly also does the right # Let's test that calling `persist_event` directly also does the right
# thing. # thing.
event3, context = self._create_duplicate_event(txn_id) event3, unpersisted_context = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event3))
self.assertNotEqual(event1.event_id, event3.event_id) self.assertNotEqual(event1.event_id, event3.event_id)
ret_event3, event_pos3, _ = self.get_success( ret_event3, event_pos3, _ = self.get_success(
@ -154,7 +160,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_events` directly also does the right # Let's test that calling `persist_events` directly also does the right
# thing. # thing.
event4, context = self._create_duplicate_event(txn_id) event4, unpersisted_context = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event4))
self.assertNotEqual(event1.event_id, event3.event_id) self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success( events, _ = self.get_success(
@ -174,8 +181,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_else_suitably_random" txn_id = "something_else_suitably_random"
# Create two duplicate events to persist at the same time # Create two duplicate events to persist at the same time
event1, context1 = self._create_duplicate_event(txn_id) event1, unpersisted_context1 = self._create_duplicate_event(txn_id)
event2, context2 = self._create_duplicate_event(txn_id) context1 = self.get_success(unpersisted_context1.persist(event1))
event2, unpersisted_context2 = self._create_duplicate_event(txn_id)
context2 = self.get_success(unpersisted_context2.persist(event2))
# Ensure their event IDs are different to start with # Ensure their event IDs are different to start with
self.assertNotEqual(event1.event_id, event2.event_id) self.assertNotEqual(event1.event_id, event2.event_id)

View file

@ -507,7 +507,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Lower the permissions of the inviter. # Lower the permissions of the inviter.
event_creation_handler = self.hs.get_event_creation_handler() event_creation_handler = self.hs.get_event_creation_handler()
requester = create_requester(inviter) requester = create_requester(inviter)
event, context = self.get_success( event, unpersisted_context = self.get_success(
event_creation_handler.create_event( event_creation_handler.create_event(
requester, requester,
{ {
@ -519,6 +519,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
}, },
) )
) )
context = self.get_success(unpersisted_context.persist(event))
self.get_success( self.get_success(
event_creation_handler.handle_new_client_event( event_creation_handler.handle_new_client_event(
requester, events_and_context=[(event, context)] requester, events_and_context=[(event, context)]

View file

@ -130,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
# Create a new message event, and try to evaluate it under the dodgy # Create a new message event, and try to evaluate it under the dodgy
# power level event. # power level event.
event, context = self.get_success( event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event( self.event_creation_handler.create_event(
self.requester, self.requester,
{ {
@ -145,6 +145,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
prev_event_ids=[pl_event_id], prev_event_ids=[pl_event_id],
) )
) )
context = self.get_success(unpersisted_context.persist(event))
bulk_evaluator = BulkPushRuleEvaluator(self.hs) bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# should not raise # should not raise
@ -170,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
"""Ensure that push rules are not calculated when disabled in the config""" """Ensure that push rules are not calculated when disabled in the config"""
# Create a new message event which should cause a notification. # Create a new message event which should cause a notification.
event, context = self.get_success( event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event( self.event_creation_handler.create_event(
self.requester, self.requester,
{ {
@ -184,6 +185,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
}, },
) )
) )
context = self.get_success(unpersisted_context.persist(event))
bulk_evaluator = BulkPushRuleEvaluator(self.hs) bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# Mock the method which calculates push rules -- we do this instead of # Mock the method which calculates push rules -- we do this instead of
@ -200,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
) -> bool: ) -> bool:
"""Returns true iff the `mentions` trigger an event push action.""" """Returns true iff the `mentions` trigger an event push action."""
# Create a new message event which should cause a notification. # Create a new message event which should cause a notification.
event, context = self.get_success( event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event( self.event_creation_handler.create_event(
self.requester, self.requester,
{ {
@ -211,7 +213,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
}, },
) )
) )
context = self.get_success(unpersisted_context.persist(event))
# Execute the push rule machinery. # Execute the push rule machinery.
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
@ -390,7 +392,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
bulk_evaluator = BulkPushRuleEvaluator(self.hs) bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# Create & persist an event to use as the parent of the relation. # Create & persist an event to use as the parent of the relation.
event, context = self.get_success( event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event( self.event_creation_handler.create_event(
self.requester, self.requester,
{ {
@ -404,6 +406,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
}, },
) )
) )
context = self.get_success(unpersisted_context.persist(event))
self.get_success( self.get_success(
self.event_creation_handler.handle_new_client_event( self.event_creation_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)] self.requester, events_and_context=[(event, context)]

View file

@ -713,7 +713,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body) self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None assert channel.resource_usage is not None
self.assertEqual(33, channel.resource_usage.db_txn_count) self.assertEqual(30, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None: def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id # POST with initial_state config key, expect new room id
@ -726,7 +726,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body) self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None assert channel.resource_usage is not None
self.assertEqual(36, channel.resource_usage.db_txn_count) self.assertEqual(32, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None: def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id # POST with visibility config key, expect new room id

View file

@ -522,7 +522,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
latest_event_ids = self.get_success( latest_event_ids = self.get_success(
self.store.get_prev_events_for_room(room_id) self.store.get_prev_events_for_room(room_id)
) )
event, context = self.get_success( event, unpersisted_context = self.get_success(
event_handler.create_event( event_handler.create_event(
self.requester, self.requester,
{ {
@ -535,6 +535,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
prev_event_ids=latest_event_ids, prev_event_ids=latest_event_ids,
) )
) )
context = self.get_success(unpersisted_context.persist(event))
self.get_success( self.get_success(
event_handler.handle_new_client_event( event_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)] self.requester, events_and_context=[(event, context)]
@ -544,7 +545,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
assert state_ids1 is not None assert state_ids1 is not None
state1 = set(state_ids1.values()) state1 = set(state_ids1.values())
event, context = self.get_success( event, unpersisted_context = self.get_success(
event_handler.create_event( event_handler.create_event(
self.requester, self.requester,
{ {
@ -557,6 +558,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
prev_event_ids=latest_event_ids, prev_event_ids=latest_event_ids,
) )
) )
context = self.get_success(unpersisted_context.persist(event))
self.get_success( self.get_success(
event_handler.handle_new_client_event( event_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)] self.requester, events_and_context=[(event, context)]

View file

@ -496,3 +496,129 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
def test_batched_state_group_storing(self) -> None:
creation_event = self.inject_state_event(
self.room, self.u_alice, EventTypes.Create, "", {}
)
state_to_event = self.get_success(
self.storage.state.get_state_groups(
self.room.to_string(), [creation_event.event_id]
)
)
current_state_group = list(state_to_event.keys())[0]
# create some unpersisted events and event contexts to store against room
events_and_context = []
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Name,
"sender": self.u_alice.to_string(),
"state_key": "",
"room_id": self.room.to_string(),
"content": {"name": "first rename of room"},
},
)
event1, unpersisted_context1 = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
events_and_context.append((event1, unpersisted_context1))
builder2 = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.JoinRules,
"sender": self.u_alice.to_string(),
"state_key": "",
"room_id": self.room.to_string(),
"content": {"join_rule": "private"},
},
)
event2, unpersisted_context2 = self.get_success(
self.event_creation_handler.create_new_client_event(builder2)
)
events_and_context.append((event2, unpersisted_context2))
builder3 = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Message,
"sender": self.u_alice.to_string(),
"room_id": self.room.to_string(),
"content": {"body": "hello from event 3", "msgtype": "m.text"},
},
)
event3, unpersisted_context3 = self.get_success(
self.event_creation_handler.create_new_client_event(builder3)
)
events_and_context.append((event3, unpersisted_context3))
builder4 = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.JoinRules,
"sender": self.u_alice.to_string(),
"state_key": "",
"room_id": self.room.to_string(),
"content": {"join_rule": "public"},
},
)
event4, unpersisted_context4 = self.get_success(
self.event_creation_handler.create_new_client_event(builder4)
)
events_and_context.append((event4, unpersisted_context4))
processed_events_and_context = self.get_success(
self.hs.get_datastores().state.store_state_deltas_for_batched(
events_and_context, self.room.to_string(), current_state_group
)
)
# check that only state events are in state_groups, and all state events are in state_groups
res = self.get_success(
self.store.db_pool.simple_select_list(
table="state_groups",
keyvalues=None,
retcols=("event_id",),
)
)
events = []
for result in res:
self.assertNotIn(event3.event_id, result)
events.append(result.get("event_id"))
for event, _ in processed_events_and_context:
if event.is_state():
self.assertIn(event.event_id, events)
# check that each unique state has state group in state_groups_state and that the
# type/state key is correct, and check that each state event's state group
# has an entry and prev event in state_group_edges
for event, context in processed_events_and_context:
if event.is_state():
state = self.get_success(
self.store.db_pool.simple_select_list(
table="state_groups_state",
keyvalues={"state_group": context.state_group_after_event},
retcols=("type", "state_key"),
)
)
self.assertEqual(event.type, state[0].get("type"))
self.assertEqual(event.state_key, state[0].get("state_key"))
groups = self.get_success(
self.store.db_pool.simple_select_list(
table="state_group_edges",
keyvalues={"state_group": str(context.state_group_after_event)},
retcols=("*",),
)
)
self.assertEqual(
context.state_group_before_event, groups[0].get("prev_state_group")
)

View file

@ -723,7 +723,7 @@ class HomeserverTestCase(TestCase):
event_creator = self.hs.get_event_creation_handler() event_creator = self.hs.get_event_creation_handler()
requester = create_requester(user) requester = create_requester(user)
event, context = self.get_success( event, unpersisted_context = self.get_success(
event_creator.create_event( event_creator.create_event(
requester, requester,
{ {
@ -735,7 +735,7 @@ class HomeserverTestCase(TestCase):
prev_event_ids=prev_event_ids, prev_event_ids=prev_event_ids,
) )
) )
context = self.get_success(unpersisted_context.persist(event))
if soft_failed: if soft_failed:
event.internal_metadata.soft_failed = True event.internal_metadata.soft_failed = True