# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import json

import six

from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import register, relations

from tests import unittest


class RelationsTestCase(unittest.HomeserverTestCase):
    """Tests for the unstable relations, aggregations and send_relation
    client APIs (enabled via `experimental_msc1849_support_enabled`).
    """

    servlets = [
        relations.register_servlets,
        room.register_servlets,
        login.register_servlets,
        register.register_servlets,
        admin.register_servlets_for_client_rest_resource,
    ]
    hijack_auth = False

    def make_homeserver(self, reactor, clock):
        # We need to enable msc1849 support for aggregations
        config = self.default_config()
        config["experimental_msc1849_support_enabled"] = True
        return self.setup_test_homeserver(config=config)

    def prepare(self, reactor, clock, hs):
        self.user_id, self.user_token = self._create_user("alice")
        self.user2_id, self.user2_token = self._create_user("bob")

        self.room = self.helper.create_room_as(self.user_id, tok=self.user_token)
        self.helper.join(self.room, user=self.user2_id, tok=self.user2_token)
        res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
        self.parent_id = res["event_id"]

    def test_send_relation(self):
        """Tests that sending a relation using the new /send_relation API
        creates the right shape of event.
        """

        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
        self.assertEqual(200, channel.code, channel.json_body)

        event_id = channel.json_body["event_id"]

        request, channel = self.make_request(
            "GET",
            "/rooms/%s/event/%s" % (self.room, event_id),
            access_token=self.user_token,
        )
        self.render(request)
        self.assertEqual(200, channel.code, channel.json_body)

        self.assert_dict(
            {
                "type": "m.reaction",
                "sender": self.user_id,
                "content": {
                    "m.relates_to": {
                        "event_id": self.parent_id,
                        "key": "👍",
                        "rel_type": RelationTypes.ANNOTATION,
                    }
                },
            },
            channel.json_body,
        )

    def test_deny_membership(self):
        """Test that we deny relations on membership events
        """
        channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
        self.assertEqual(400, channel.code, channel.json_body)

    def test_deny_double_react(self):
        """Test that a user cannot send the same annotation (same key) twice
        on the same event.
        """
        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
        self.assertEqual(200, channel.code, channel.json_body)

        # Same user, same key: should be rejected.
        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
        self.assertEqual(400, channel.code, channel.json_body)

    def test_basic_paginate_relations(self):
        """Tests that calling the pagination API correctly returns the latest
        relations.
        """
        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
        self.assertEqual(200, channel.code, channel.json_body)

        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
        self.assertEqual(200, channel.code, channel.json_body)
        annotation_id = channel.json_body["event_id"]

        request, channel = self.make_request(
            "GET",
            "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
            % (self.room, self.parent_id),
            access_token=self.user_token,
        )
        self.render(request)
        self.assertEqual(200, channel.code, channel.json_body)

        # We expect to get back a single pagination result, which is the full
        # relation event we sent above.
        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
        self.assert_dict(
            {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"},
            channel.json_body["chunk"][0],
        )

        # We also expect to get the original event (the id of which is self.parent_id)
        self.assertEqual(
            channel.json_body["original_event"]["event_id"], self.parent_id
        )

        # Make sure next_batch has something in it that looks like it could be a
        # valid token.
        self.assertIsInstance(
            channel.json_body.get("next_batch"), six.string_types, channel.json_body
        )

    def test_repeated_paginate_relations(self):
        """Test that if we paginate using a limit and tokens then we get the
        expected events.
        """

        expected_event_ids = []
        for _ in range(10):
            channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
            self.assertEqual(200, channel.code, channel.json_body)
            expected_event_ids.append(channel.json_body["event_id"])

        prev_token = None
        found_event_ids = []
        # Upper bound on pages so a pagination bug cannot loop forever.
        for _ in range(20):
            from_token = ""
            if prev_token:
                from_token = "&from=" + prev_token

            request, channel = self.make_request(
                "GET",
                "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
                % (self.room, self.parent_id, from_token),
                access_token=self.user_token,
            )
            self.render(request)
            self.assertEqual(200, channel.code, channel.json_body)

            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
            next_batch = channel.json_body.get("next_batch")

            self.assertNotEqual(prev_token, next_batch)
            prev_token = next_batch

            if not prev_token:
                break

        # We paginated backwards, so reverse
        found_event_ids.reverse()
        self.assertEqual(found_event_ids, expected_event_ids)

    def test_aggregation_pagination_groups(self):
        """Test that we can paginate annotation groups correctly.
        """

        # We need to create ten separate users to send each reaction.
        access_tokens = [self.user_token, self.user2_token]
        idx = 0
        while len(access_tokens) < 10:
            user_id, token = self._create_user("test" + str(idx))
            idx += 1

            self.helper.join(self.room, user=user_id, tok=token)
            access_tokens.append(token)

        idx = 0
        sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
        for key in itertools.chain.from_iterable(
            itertools.repeat(key, num) for key, num in sent_groups.items()
        ):
            channel = self._send_relation(
                RelationTypes.ANNOTATION,
                "m.reaction",
                key=key,
                access_token=access_tokens[idx],
            )
            self.assertEqual(200, channel.code, channel.json_body)

            # Rotate through the users so nobody reacts with the same key twice.
            idx += 1
            idx %= len(access_tokens)

        prev_token = None
        found_groups = {}
        # Upper bound on pages so a pagination bug cannot loop forever.
        for _ in range(20):
            from_token = ""
            if prev_token:
                from_token = "&from=" + prev_token

            request, channel = self.make_request(
                "GET",
                "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s"
                % (self.room, self.parent_id, from_token),
                access_token=self.user_token,
            )
            self.render(request)
            self.assertEqual(200, channel.code, channel.json_body)

            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)

            for groups in channel.json_body["chunk"]:
                # We only expect reactions
                self.assertEqual(groups["type"], "m.reaction", channel.json_body)

                # We should only see each key once
                self.assertNotIn(groups["key"], found_groups, channel.json_body)

                found_groups[groups["key"]] = groups["count"]

            next_batch = channel.json_body.get("next_batch")

            self.assertNotEqual(prev_token, next_batch)
            prev_token = next_batch

            if not prev_token:
                break

        self.assertEqual(sent_groups, found_groups)

    def test_aggregation_pagination_within_group(self):
        """Test that we can paginate within an annotation group.
        """

        # We need to create ten separate users to send each reaction.
        access_tokens = [self.user_token, self.user2_token]
        idx = 0
        while len(access_tokens) < 10:
            user_id, token = self._create_user("test" + str(idx))
            idx += 1

            self.helper.join(self.room, user=user_id, tok=token)
            access_tokens.append(token)

        idx = 0
        expected_event_ids = []
        for _ in range(10):
            channel = self._send_relation(
                RelationTypes.ANNOTATION,
                "m.reaction",
                key="👍",
                access_token=access_tokens[idx],
            )
            self.assertEqual(200, channel.code, channel.json_body)
            expected_event_ids.append(channel.json_body["event_id"])

            idx += 1

        # Also send a different type of reaction so that we test we don't see it
        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
        self.assertEqual(200, channel.code, channel.json_body)

        prev_token = None
        found_event_ids = []
        encoded_key = six.moves.urllib.parse.quote_plus("👍".encode("utf-8"))
        # Upper bound on pages so a pagination bug cannot loop forever.
        for _ in range(20):
            from_token = ""
            if prev_token:
                from_token = "&from=" + prev_token

            request, channel = self.make_request(
                "GET",
                "/_matrix/client/unstable/rooms/%s"
                "/aggregations/%s/%s/m.reaction/%s?limit=1%s"
                % (
                    self.room,
                    self.parent_id,
                    RelationTypes.ANNOTATION,
                    encoded_key,
                    from_token,
                ),
                access_token=self.user_token,
            )
            self.render(request)
            self.assertEqual(200, channel.code, channel.json_body)

            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)

            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])

            next_batch = channel.json_body.get("next_batch")

            self.assertNotEqual(prev_token, next_batch)
            prev_token = next_batch

            if not prev_token:
                break

        # We paginated backwards, so reverse
        found_event_ids.reverse()
        self.assertEqual(found_event_ids, expected_event_ids)

    def test_aggregation(self):
        """Test that annotations get correctly aggregated.
        """

        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
        self.assertEqual(200, channel.code, channel.json_body)

        channel = self._send_relation(
            RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
        )
        self.assertEqual(200, channel.code, channel.json_body)

        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
        self.assertEqual(200, channel.code, channel.json_body)

        request, channel = self.make_request(
            "GET",
            "/_matrix/client/unstable/rooms/%s/aggregations/%s"
            % (self.room, self.parent_id),
            access_token=self.user_token,
        )
        self.render(request)
        self.assertEqual(200, channel.code, channel.json_body)

        self.assertEqual(
            channel.json_body,
            {
                "chunk": [
                    {"type": "m.reaction", "key": "a", "count": 2},
                    {"type": "m.reaction", "key": "b", "count": 1},
                ]
            },
        )

    def test_aggregation_redactions(self):
        """Test that annotations get correctly aggregated after a redaction.
        """

        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
        self.assertEqual(200, channel.code, channel.json_body)
        to_redact_event_id = channel.json_body["event_id"]

        channel = self._send_relation(
            RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
        )
        self.assertEqual(200, channel.code, channel.json_body)

        # Now let's redact one of the 'a' reactions
        request, channel = self.make_request(
            "POST",
            "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id),
            access_token=self.user_token,
            content={},
        )
        self.render(request)
        self.assertEqual(200, channel.code, channel.json_body)

        request, channel = self.make_request(
            "GET",
            "/_matrix/client/unstable/rooms/%s/aggregations/%s"
            % (self.room, self.parent_id),
            access_token=self.user_token,
        )
        self.render(request)
        self.assertEqual(200, channel.code, channel.json_body)

        self.assertEqual(
            channel.json_body,
            {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
        )

    def test_aggregation_must_be_annotation(self):
        """Test that aggregations must be annotations.
        """

        request, channel = self.make_request(
            "GET",
            "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1"
            % (self.room, self.parent_id, RelationTypes.REPLACE),
            access_token=self.user_token,
        )
        self.render(request)
        self.assertEqual(400, channel.code, channel.json_body)

    def test_aggregation_get_event(self):
        """Test that annotations and references get correctly bundled when
        getting the parent event.
        """

        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
        self.assertEqual(200, channel.code, channel.json_body)

        channel = self._send_relation(
            RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
        )
        self.assertEqual(200, channel.code, channel.json_body)

        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
        self.assertEqual(200, channel.code, channel.json_body)

        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
        self.assertEqual(200, channel.code, channel.json_body)
        reply_1 = channel.json_body["event_id"]

        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
        self.assertEqual(200, channel.code, channel.json_body)
        reply_2 = channel.json_body["event_id"]

        request, channel = self.make_request(
            "GET",
            "/rooms/%s/event/%s" % (self.room, self.parent_id),
            access_token=self.user_token,
        )
        self.render(request)
        self.assertEqual(200, channel.code, channel.json_body)

        self.assertEqual(
            channel.json_body["unsigned"].get("m.relations"),
            {
                RelationTypes.ANNOTATION: {
                    "chunk": [
                        {"type": "m.reaction", "key": "a", "count": 2},
                        {"type": "m.reaction", "key": "b", "count": 1},
                    ]
                },
                RelationTypes.REFERENCE: {
                    "chunk": [{"event_id": reply_1}, {"event_id": reply_2}]
                },
            },
        )

    def test_edit(self):
        """Test that a simple edit works.
        """

        new_body = {"msgtype": "m.text", "body": "I've been edited!"}
        channel = self._send_relation(
            RelationTypes.REPLACE,
            "m.room.message",
            content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
        )
        self.assertEqual(200, channel.code, channel.json_body)

        edit_event_id = channel.json_body["event_id"]

        request, channel = self.make_request(
            "GET",
            "/rooms/%s/event/%s" % (self.room, self.parent_id),
            access_token=self.user_token,
        )
        self.render(request)
        self.assertEqual(200, channel.code, channel.json_body)

        # The returned content should be the edited content.
        self.assertEqual(channel.json_body["content"], new_body)

        relations_dict = channel.json_body["unsigned"].get("m.relations")
        self.assertIn(RelationTypes.REPLACE, relations_dict)

        m_replace_dict = relations_dict[RelationTypes.REPLACE]
        for key in ["event_id", "sender", "origin_server_ts"]:
            self.assertIn(key, m_replace_dict)

        self.assert_dict(
            {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
        )

    def test_multi_edit(self):
        """Test that multiple edits, including attempts by people who
        shouldn't be allowed, are correctly handled.
        """

        channel = self._send_relation(
            RelationTypes.REPLACE,
            "m.room.message",
            content={
                "msgtype": "m.text",
                "body": "Wibble",
                "m.new_content": {"msgtype": "m.text", "body": "First edit"},
            },
        )
        self.assertEqual(200, channel.code, channel.json_body)

        new_body = {"msgtype": "m.text", "body": "I've been edited!"}
        channel = self._send_relation(
            RelationTypes.REPLACE,
            "m.room.message",
            content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
        )
        self.assertEqual(200, channel.code, channel.json_body)

        edit_event_id = channel.json_body["event_id"]

        # An edit with a different event type than the original must not be
        # applied.
        channel = self._send_relation(
            RelationTypes.REPLACE,
            "m.room.message.WRONG_TYPE",
            content={
                "msgtype": "m.text",
                "body": "Wibble",
                "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
            },
        )
        self.assertEqual(200, channel.code, channel.json_body)

        request, channel = self.make_request(
            "GET",
            "/rooms/%s/event/%s" % (self.room, self.parent_id),
            access_token=self.user_token,
        )
        self.render(request)
        self.assertEqual(200, channel.code, channel.json_body)

        self.assertEqual(channel.json_body["content"], new_body)

        relations_dict = channel.json_body["unsigned"].get("m.relations")
        self.assertIn(RelationTypes.REPLACE, relations_dict)

        m_replace_dict = relations_dict[RelationTypes.REPLACE]
        for key in ["event_id", "sender", "origin_server_ts"]:
            self.assertIn(key, m_replace_dict)

        self.assert_dict(
            {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
        )

    def test_relations_redaction_redacts_edits(self):
        """Test that edits of an event are redacted when the original event
        is redacted.
        """
        # Send a new event
        res = self.helper.send(self.room, body="Heyo!", tok=self.user_token)
        original_event_id = res["event_id"]

        # Add a relation
        channel = self._send_relation(
            RelationTypes.REPLACE,
            "m.room.message",
            parent_id=original_event_id,
            content={
                "msgtype": "m.text",
                "body": "Wibble",
                "m.new_content": {"msgtype": "m.text", "body": "First edit"},
            },
        )
        self.assertEqual(200, channel.code, channel.json_body)

        # Check the relation is returned
        request, channel = self.make_request(
            "GET",
            "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
            % (self.room, original_event_id),
            access_token=self.user_token,
        )
        self.render(request)
        self.assertEqual(200, channel.code, channel.json_body)

        self.assertIn("chunk", channel.json_body)
        self.assertEqual(len(channel.json_body["chunk"]), 1)

        # Redact the original event
        request, channel = self.make_request(
            "PUT",
            "/rooms/%s/redact/%s/%s"
            % (self.room, original_event_id, "test_relations_redaction_redacts_edits"),
            access_token=self.user_token,
            content="{}",
        )
        self.render(request)
        self.assertEqual(200, channel.code, channel.json_body)

        # Try to check for remaining m.replace relations
        request, channel = self.make_request(
            "GET",
            "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
            % (self.room, original_event_id),
            access_token=self.user_token,
        )
        self.render(request)
        self.assertEqual(200, channel.code, channel.json_body)

        # Check that no relations are returned
        self.assertIn("chunk", channel.json_body)
        self.assertEqual(channel.json_body["chunk"], [])

    def test_aggregations_redaction_prevents_access_to_aggregations(self):
        """Test that annotations of an event are redacted when the original
        event is redacted.
        """
        # Send a new event
        res = self.helper.send(self.room, body="Hello!", tok=self.user_token)
        original_event_id = res["event_id"]

        # Add a relation
        channel = self._send_relation(
            RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id
        )
        self.assertEqual(200, channel.code, channel.json_body)

        # Redact the original
        request, channel = self.make_request(
            "PUT",
            "/rooms/%s/redact/%s/%s"
            % (
                self.room,
                original_event_id,
                "test_aggregations_redaction_prevents_access_to_aggregations",
            ),
            access_token=self.user_token,
            content="{}",
        )
        self.render(request)
        self.assertEqual(200, channel.code, channel.json_body)

        # Check that aggregations returns zero
        request, channel = self.make_request(
            "GET",
            "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction"
            % (self.room, original_event_id),
            access_token=self.user_token,
        )
        self.render(request)
        self.assertEqual(200, channel.code, channel.json_body)

        self.assertIn("chunk", channel.json_body)
        self.assertEqual(channel.json_body["chunk"], [])

    def _send_relation(
        self,
        relation_type,
        event_type,
        key=None,
        content=None,
        access_token=None,
        parent_id=None,
    ):
        """Helper function to send a relation pointing at `parent_id`, or at
        `self.parent_id` if `parent_id` is not given.

        Args:
            relation_type (str): One of `RelationTypes`
            event_type (str): The type of the event to create
            key (str|None): The aggregation key used for m.annotation relation
                type.
            content (dict|None): The content of the created event. Defaults to
                an empty dict.
            access_token (str|None): The access token used to send the
                relation, defaults to `self.user_token`
            parent_id (str|None): The event_id this relation relates to. If
                None, then self.parent_id

        Returns:
            FakeChannel
        """
        if not access_token:
            access_token = self.user_token

        # Avoid a shared mutable default argument.
        if content is None:
            content = {}

        query = ""
        if key:
            query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8"))

        original_id = parent_id if parent_id else self.parent_id

        request, channel = self.make_request(
            "POST",
            "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
            % (self.room, original_id, relation_type, event_type, query),
            json.dumps(content).encode("utf-8"),
            access_token=access_token,
        )
        self.render(request)
        return channel

    def _create_user(self, localpart):
        """Register a user with a fixed password and log them in.

        Returns:
            tuple[str, str]: The new user's ID and access token.
        """
        user_id = self.register_user(localpart, "abc123")
        access_token = self.login(localpart, "abc123")

        return user_id, access_token