0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 19:23:53 +01:00

Additional tests for third-party event rules (#8468)

* Optimise and test state fetching for 3p event rules

Getting all the events at once is much more efficient than getting them
individually

* Test that 3p event rules can modify events
This commit is contained in:
Richard van der Hoff 2020-10-06 16:31:31 +01:00 committed by GitHub
parent 9c0b168cff
commit a024461130
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 79 additions and 18 deletions

1
changelog.d/8468.misc Normal file
View file

@ -0,0 +1 @@
Additional testing for `ThirdPartyEventRules`.

View file

@ -61,12 +61,14 @@ class ThirdPartyEventRules:
prev_state_ids = await context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
# Retrieve the state events from the database. # Retrieve the state events from the database.
state_events = {} events = await self.store.get_events(prev_state_ids.values())
for key, event_id in prev_state_ids.items(): state_events = {(ev.type, ev.state_key): ev for ev in events.values()}
state_events[key] = await self.store.get_event(event_id, allow_none=True)
ret = await self.third_party_rules.check_event_allowed(event, state_events) # The module can modify the event slightly if it wants, but caution should be
return ret # exercised, and it's likely to go very wrong if applied to events received over
# federation.
return await self.third_party_rules.check_event_allowed(event, state_events)
async def on_create_room( async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool self, requester: Requester, config: dict, is_requester_admin: bool

View file

@ -12,33 +12,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import threading
from mock import Mock
from synapse.events import EventBase
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client.v1 import login, room from synapse.rest.client.v1 import login, room
from synapse.types import Requester from synapse.types import Requester, StateMap
from tests import unittest from tests import unittest
thread_local = threading.local()
class ThirdPartyRulesTestModule: class ThirdPartyRulesTestModule:
def __init__(self, config, *args, **kwargs): def __init__(self, config, module_api):
pass # keep a record of the "current" rules module, so that the test can patch
# it if desired.
thread_local.rules_module = self
async def on_create_room( async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool self, requester: Requester, config: dict, is_requester_admin: bool
): ):
return True return True
async def check_event_allowed(self, event, context): async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
if event.type == "foo.bar.forbidden": return True
return False
else:
return True
@staticmethod @staticmethod
def parse_config(config): def parse_config(config):
return config return config
def current_rules_module() -> ThirdPartyRulesTestModule:
return thread_local.rules_module
class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [
admin.register_servlets, admin.register_servlets,
@ -46,15 +56,13 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def default_config(self):
config = self.default_config() config = super().default_config()
config["third_party_event_rules"] = { config["third_party_event_rules"] = {
"module": __name__ + ".ThirdPartyRulesTestModule", "module": __name__ + ".ThirdPartyRulesTestModule",
"config": {}, "config": {},
} }
return config
self.hs = self.setup_test_homeserver(config=config)
return self.hs
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor, clock, homeserver):
# Create a user and room to play with during the tests # Create a user and room to play with during the tests
@ -67,6 +75,14 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
"""Tests that a forbidden event is forbidden from being sent, but an allowed one """Tests that a forbidden event is forbidden from being sent, but an allowed one
can be sent. can be sent.
""" """
# patch the rules module with a Mock which will return False for some event
# types
async def check(ev, state):
return ev.type != "foo.bar.forbidden"
callback = Mock(spec=[], side_effect=check)
current_rules_module().check_event_allowed = callback
request, channel = self.make_request( request, channel = self.make_request(
"PUT", "PUT",
"/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id, "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
@ -76,6 +92,16 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
self.render(request) self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
callback.assert_called_once()
# there should be various state events in the state arg: do some basic checks
state_arg = callback.call_args[0][1]
for k in (("m.room.create", ""), ("m.room.member", self.user_id)):
self.assertIn(k, state_arg)
ev = state_arg[k]
self.assertEqual(ev.type, k[0])
self.assertEqual(ev.state_key, k[1])
request, channel = self.make_request( request, channel = self.make_request(
"PUT", "PUT",
"/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id, "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id,
@ -84,3 +110,35 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
) )
self.render(request) self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
def test_modify_event(self):
"""Tests that the module can successfully tweak an event before it is persisted.
"""
# first patch the event checker so that it will modify the event
async def check(ev: EventBase, state):
ev.content = {"x": "y"}
return True
current_rules_module().check_event_allowed = check
# now send the event
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
{"x": "x"},
access_token=self.tok,
)
self.render(request)
self.assertEqual(channel.result["code"], b"200", channel.result)
event_id = channel.json_body["event_id"]
# ... and check that it got modified
request, channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
self.render(request)
self.assertEqual(channel.result["code"], b"200", channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")