synapse/tests/replication/slave/storage/test_events.py
Dan Callahan 1d5f0e3529
Bump black configuration to target py36 (#9781)
Signed-off-by: Dan Callahan <danc@element.io>
2021-04-13 10:41:34 +01:00

390 lines
13 KiB
Python

# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Iterable, Optional
from canonicaljson import encode_canonical_json
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser
from synapse.types import PersistedEventPosition
from tests.server import FakeTransport
from ._base import BaseSlavedStoreTestCase
# Default sender for events built by these tests; also the user passed to
# store_room() and used as the "creator" of the test room.
USER_ID = "@feeling:test"
# A second user, used where a test needs somebody other than USER_ID
# (invites, additional joins, push-action recipients).
USER_ID_2 = "@bright:test"
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably intended as an internal-metadata dict for outlier events — confirm
# before relying on it.
OUTLIER = {"outlier": True}
# The room stored on the master in prepare() and the default room_id for
# build_event().
ROOM_ID = "!room:test"
# Module-level logger, used for progress markers in the multi-event persist test.
logger = logging.getLogger(__name__)
def dict_equals(self, other):
    """Equality check installed as ``__eq__`` by ``patch__eq__``.

    Two objects compare equal when the canonical JSON encodings of their
    ``get_pdu_json()`` results are identical.

    Args:
        self: the left-hand object; must provide ``get_pdu_json()``.
        other: the right-hand object; must provide ``get_pdu_json()``.

    Returns:
        bool: True if both canonical encodings match, False otherwise.
    """
    # The left operand is encoded first, then the right one.
    return encode_canonical_json(self.get_pdu_json()) == encode_canonical_json(
        other.get_pdu_json()
    )
def patch__eq__(cls):
    """Temporarily replace ``cls.__eq__`` with ``dict_equals``.

    Only the class's *own* ``__eq__`` (the entry in ``cls.__dict__``, if any) is
    remembered. Looking the attribute up with ``getattr`` would also find an
    inherited implementation -- including ``dict_equals`` itself when a base
    class has already been patched -- and restoring that value would leave the
    class permanently modified.

    Args:
        cls: the class whose equality operator should be patched.

    Returns:
        A zero-argument callable which undoes the patch: it restores the
        class's own ``__eq__`` if it had one, and otherwise removes the
        attribute so that the inherited implementation applies again. Calling
        it more than once is harmless.
    """
    # Sentinel distinguishing "class had no __eq__ of its own" from any real value.
    missing = object()
    original = cls.__dict__.get("__eq__", missing)
    cls.__eq__ = dict_equals

    def unpatch():
        if original is not missing:
            cls.__eq__ = original
        elif "__eq__" in cls.__dict__:
            # The class had no __eq__ of its own: drop ours so that lookups
            # fall through to the base class again.
            del cls.__eq__

    return unpatch
class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
    """Tests for ``SlavedEventStore``.

    The tests build events with ``build_event``, persist them through
    ``self.storage.persistence`` on the master, call ``self.replicate()`` and
    then inspect the slaved store.

    NOTE(review): ``self.check``, ``self.replicate``, ``self.master_store`` and
    ``self.slaved_store`` come from the base class, which is not visible here.
    ``check(method, args, expected)`` presumably compares the named store
    method's result against ``expected`` -- confirm against
    ``BaseSlavedStoreTestCase``.
    """

    STORE_TYPE = SlavedEventStore

    def setUp(self):
        # Patch up the equality operator for events so that we can check
        # whether lists of events match using assertEquals
        self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
        return super().setUp()

    def prepare(self, *args, **kwargs):
        """Run the base-class preparation, then store the test room on the master."""
        super().prepare(*args, **kwargs)

        self.get_success(
            self.master_store.store_room(
                ROOM_ID,
                USER_ID,
                is_public=False,
                room_version=RoomVersions.V1,
            )
        )

    def tearDown(self):
        """Undo the ``__eq__`` patches from ``setUp`` and run the base teardown."""
        # A plain loop: the callables are invoked purely for their side effects.
        for unpatch in self.unpatches:
            unpatch()
        # Chain to the parent so that any teardown defined further up the
        # hierarchy still runs.
        return super().tearDown()

    def test_get_latest_event_ids_in_room(self):
        """Check ``get_latest_event_ids_in_room`` after each replicated event."""
        create = self.persist(type="m.room.create", key="", creator=USER_ID)
        self.replicate()
        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])

        join = self.persist(
            type="m.room.member",
            key=USER_ID,
            membership="join",
            prev_events=[(create.event_id, {})],
        )
        self.replicate()
        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])

    def test_redactions(self):
        """Check ``get_event`` before and after a redaction is replicated."""
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")

        msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
        self.replicate()
        self.check("get_event", [msg.event_id], msg)

        redaction = self.persist(type="m.room.redaction", redacts=msg.event_id)
        self.replicate()

        # Build the event we expect to get back: the original message with its
        # content emptied and the redaction recorded under "unsigned".
        msg_dict = msg.get_dict()
        msg_dict["content"] = {}
        msg_dict["unsigned"]["redacted_by"] = redaction.event_id
        msg_dict["unsigned"]["redacted_because"] = redaction
        redacted = make_event_from_dict(
            msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict()
        )
        self.check("get_event", [msg.event_id], redacted)

    def test_backfilled_redactions(self):
        """As ``test_redactions``, but the redaction is persisted with ``backfill=True``."""
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")

        msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
        self.replicate()
        self.check("get_event", [msg.event_id], msg)

        redaction = self.persist(
            type="m.room.redaction", redacts=msg.event_id, backfill=True
        )
        self.replicate()

        # Build the event we expect to get back: the original message with its
        # content emptied and the redaction recorded under "unsigned".
        msg_dict = msg.get_dict()
        msg_dict["content"] = {}
        msg_dict["unsigned"]["redacted_by"] = redaction.event_id
        msg_dict["unsigned"]["redacted_because"] = redaction
        redacted = make_event_from_dict(
            msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict()
        )
        self.check("get_event", [msg.event_id], redacted)

    def test_invites(self):
        """Check ``get_invited_rooms_for_local_user`` before and after an invite."""
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
        event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")

        self.replicate()

        self.check(
            "get_invited_rooms_for_local_user",
            [USER_ID_2],
            [
                RoomsForUser(
                    ROOM_ID,
                    USER_ID,
                    "invite",
                    event.event_id,
                    event.internal_metadata.stream_ordering,
                )
            ],
        )

    def test_push_actions_for_user(self):
        """Check the unread push-action counts for USER_ID_2 as messages arrive."""
        self.persist(type="m.room.create", key="", creator=USER_ID)
        # NOTE(review): these two events use type "m.room.join", whereas every
        # other test in this class uses "m.room.member" for joins -- confirm
        # whether that is intentional.
        self.persist(type="m.room.join", key=USER_ID, membership="join")
        self.persist(
            type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
        )
        event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
        self.replicate()
        # No push actions were attached to any event so far.
        self.check(
            "get_unread_event_push_actions_by_room_for_user",
            [ROOM_ID, USER_ID_2, event1.event_id],
            {"highlight_count": 0, "unread_count": 0, "notify_count": 0},
        )

        # A message carrying a plain "notify" action for USER_ID_2.
        self.persist(
            type="m.room.message",
            msgtype="m.text",
            body="world",
            push_actions=[(USER_ID_2, ["notify"])],
        )
        self.replicate()
        self.check(
            "get_unread_event_push_actions_by_room_for_user",
            [ROOM_ID, USER_ID_2, event1.event_id],
            {"highlight_count": 0, "unread_count": 0, "notify_count": 1},
        )

        # A message carrying "notify" plus a highlight tweak for USER_ID_2.
        self.persist(
            type="m.room.message",
            msgtype="m.text",
            body="world",
            push_actions=[
                (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}])
            ],
        )
        self.replicate()
        self.check(
            "get_unread_event_push_actions_by_room_for_user",
            [ROOM_ID, USER_ID_2, event1.event_id],
            {"highlight_count": 1, "unread_count": 0, "notify_count": 2},
        )

    def test_get_rooms_for_user_with_stream_ordering(self):
        """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
        by rows in the events stream
        """
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")
        self.replicate()
        self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())

        j2 = self.persist(
            type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
        )
        self.replicate()

        expected_pos = PersistedEventPosition(
            "master", j2.internal_metadata.stream_ordering
        )
        self.check(
            "get_rooms_for_user_with_stream_ordering",
            (USER_ID_2,),
            {(ROOM_ID, expected_pos)},
        )

    def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
        """Check that current_state invalidation happens correctly with multiple events
        in the persistence batch.

        This test attempts to reproduce a race condition between the event persistence
        loop and a worker-based Sync handler.

        The problem occurred when the master persisted several events in one batch. It
        only updates the current_state at the end of each batch, so the obvious thing
        to do is then to issue a current_state_delta stream update corresponding to the
        last stream_id in the batch.

        However, that raises the possibility that a worker will see the replication
        notification for a join event before the current_state caches are invalidated.

        The test involves:
         * creating a join and a message event for a user, and persisting them in the
           same batch

         * controlling the replication stream so that updates are sent gradually

         * between each bunch of replication updates, check that we see a consistent
           snapshot of the state.
        """
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")
        self.replicate()
        self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())

        # limit the replication rate
        repl_transport = self._server_transport
        assert isinstance(repl_transport, FakeTransport)
        repl_transport.autoflush = False

        # build the join and message events and persist them in the same batch.
        logger.info("----- build test events ------")
        j2, j2ctx = self.build_event(
            type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
        )
        msg, msgctx = self.build_event()
        self.get_success(
            self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
        )
        self.replicate()

        # Read the current position via a RoomEventSource pointed at the slaved store.
        event_source = RoomEventSource(self.hs)
        event_source.store = self.slaved_store
        current_token = self.get_success(event_source.get_current_key())

        # gradually stream out the replication
        while repl_transport.buffer:
            logger.info("------ flush ------")
            repl_transport.flush(30)
            self.pump(0)

            prev_token = current_token
            current_token = self.get_success(event_source.get_current_key())

            # attempt to replicate the behaviour of the sync handler.
            #
            # First, we get a list of the rooms we are joined to
            joined_rooms = self.get_success(
                self.slaved_store.get_rooms_for_user_with_stream_ordering(USER_ID_2)
            )

            # Then, we get a list of the events since the last sync
            membership_changes = self.get_success(
                self.slaved_store.get_membership_changes_for_user(
                    USER_ID_2, prev_token, current_token
                )
            )

            logger.info(
                "%s->%s: joined_rooms=%r membership_changes=%r",
                prev_token,
                current_token,
                joined_rooms,
                membership_changes,
            )

            # the membership change is only any use to us if the room is in the
            # joined_rooms list.
            if membership_changes:
                expected_pos = PersistedEventPosition(
                    "master", j2.internal_metadata.stream_ordering
                )
                self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)})

    # Counter used by build_event for the event ID, the default depth and
    # origin_server_ts. It starts as a class attribute; the first
    # ``self.event_id += 1`` creates a per-instance copy.
    event_id = 0

    def persist(self, backfill=False, **kwargs):
        """Build an event and persist it via the master's persistence layer.

        Args:
            backfill: if True, persist through ``persist_events(...,
                backfilled=True)``; otherwise through ``persist_event``.
            **kwargs: passed straight through to ``build_event``.

        Returns:
            synapse.events.FrozenEvent: The event that was persisted.
        """
        event, context = self.build_event(**kwargs)

        if backfill:
            self.get_success(
                self.storage.persistence.persist_events(
                    [(event, context)], backfilled=True
                )
            )
        else:
            self.get_success(self.storage.persistence.persist_event(event, context))

        return event

    def build_event(
        self,
        sender=USER_ID,
        room_id=ROOM_ID,
        type="m.room.message",
        key=None,
        internal: Optional[dict] = None,
        depth=None,
        prev_events: Optional[list] = None,
        auth_events: Optional[list] = None,
        prev_state: Optional[list] = None,
        redacts=None,
        push_actions: Iterable = frozenset(),
        **content,
    ):
        """Build (but do not persist) an event and compute its context.

        Also stages the given push actions for the event on the master store and
        advances ``self.event_id``.

        Args:
            sender: value for the event's "sender" field.
            room_id: value for the event's "room_id" field.
            type: value for the event's "type" field.
            key: if not None, used as the "state_key" (and "prev_state" is
                included in the event as well).
            internal: internal metadata dict for the event; defaults to empty.
            depth: value for "depth"; defaults to the current ``self.event_id``.
            prev_events: value for "prev_events"; if empty or None, derived from
                the master store's latest event IDs in the room.
            auth_events: value for "auth_events"; defaults to an empty list.
            prev_state: value for "prev_state" (state events only); defaults to
                an empty list.
            redacts: if not None, set as the event's "redacts" field.
            push_actions: iterable of ``(user_id, actions)`` pairs to stage.
            **content: becomes the event's "content" dict.

        Returns:
            A ``(event, context)`` tuple: the result of ``make_event_from_dict``
            and of the state handler's ``compute_event_context`` for it.
        """
        prev_events = prev_events or []
        auth_events = auth_events or []
        prev_state = prev_state or []

        if depth is None:
            depth = self.event_id

        if not prev_events:
            # Hang the new event off whatever is currently latest in the room.
            latest_event_ids = self.get_success(
                self.master_store.get_latest_event_ids_in_room(room_id)
            )
            prev_events = [(ev_id, {}) for ev_id in latest_event_ids]

        event_dict = {
            "sender": sender,
            "type": type,
            "content": content,
            "event_id": "$%d:blue" % (self.event_id,),
            "room_id": room_id,
            "depth": depth,
            "origin_server_ts": self.event_id,
            "prev_events": prev_events,
            "auth_events": auth_events,
        }
        if key is not None:
            event_dict["state_key"] = key
            event_dict["prev_state"] = prev_state

        if redacts is not None:
            event_dict["redacts"] = redacts

        event = make_event_from_dict(event_dict, internal_metadata_dict=internal or {})

        self.event_id += 1
        state_handler = self.hs.get_state_handler()
        context = self.get_success(state_handler.compute_event_context(event))

        self.get_success(
            self.master_store.add_push_actions_to_staging(
                event.event_id,
                # push_actions is an iterable of (user_id, actions) pairs.
                dict(push_actions),
                False,
            )
        )
        return event, context