Tests: replace mocked Authenticator with the real thing (#11913)

If we prepopulate the test homeserver with a key for a remote homeserver, we
can make federation requests to it without having to stub out the
authenticator. This has two advantages:

 * means that what we are testing is closer to reality (ie, we now have
   complete tests for the incoming-request-authorisation flow)

 * some tests require that other objects be signed by the remote server (eg,
   the event in `/send_join`), and doing that would require a whole separate
   set of mocking out. It's much simpler just to use real keys.
This commit is contained in:
Richard van der Hoff 2022-02-11 12:06:02 +00:00 committed by GitHub
parent d36943c4df
commit c3db7a0b59
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 117 additions and 44 deletions

1
changelog.d/11913.misc Normal file
View file

@ -0,0 +1 @@
Tests: replace mocked `Authenticator` with the real thing.

View file

@ -47,7 +47,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
) )
# Get the room complexity # Get the room complexity
channel = self.make_request( channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
) )
self.assertEquals(200, channel.code) self.assertEquals(200, channel.code)
@ -59,7 +59,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23) store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
# Get the room complexity again -- make sure it's our artificial value # Get the room complexity again -- make sure it's our artificial value
channel = self.make_request( channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
) )
self.assertEquals(200, channel.code) self.assertEquals(200, channel.code)

View file

@ -113,7 +113,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
room_1 = self.helper.create_room_as(u1, tok=u1_token) room_1 = self.helper.create_room_as(u1, tok=u1_token)
self.inject_room_member(room_1, "@user:other.example.com", "join") self.inject_room_member(room_1, "@user:other.example.com", "join")
channel = self.make_request( channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,) "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
) )
self.assertEquals(200, channel.code, channel.result) self.assertEquals(200, channel.code, channel.result)
@ -145,7 +145,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
room_1 = self.helper.create_room_as(u1, tok=u1_token) room_1 = self.helper.create_room_as(u1, tok=u1_token)
channel = self.make_request( channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,) "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
) )
self.assertEquals(403, channel.code, channel.result) self.assertEquals(403, channel.code, channel.result)

View file

@ -245,7 +245,7 @@ class FederationKnockingTestCase(
self.hs, room_id, user_id self.hs, room_id, user_id
) )
channel = self.make_request( channel = self.make_signed_federation_request(
"GET", "GET",
"/_matrix/federation/v1/make_knock/%s/%s?ver=%s" "/_matrix/federation/v1/make_knock/%s/%s?ver=%s"
% ( % (
@ -288,7 +288,7 @@ class FederationKnockingTestCase(
) )
# Send the signed knock event into the room # Send the signed knock event into the room
channel = self.make_request( channel = self.make_signed_federation_request(
"PUT", "PUT",
"/_matrix/federation/v1/send_knock/%s/%s" "/_matrix/federation/v1/send_knock/%s/%s"
% (room_id, signed_knock_event.event_id), % (room_id, signed_knock_event.event_id),

View file

@ -22,10 +22,9 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
"""Test that unauthenticated requests to the public rooms directory 403 when """Test that unauthenticated requests to the public rooms directory 403 when
allow_public_rooms_over_federation is False. allow_public_rooms_over_federation is False.
""" """
channel = self.make_request( channel = self.make_signed_federation_request(
"GET", "GET",
"/_matrix/federation/v1/publicRooms", "/_matrix/federation/v1/publicRooms",
federation_auth_origin=b"example.com",
) )
self.assertEquals(403, channel.code) self.assertEquals(403, channel.code)
@ -34,9 +33,8 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
"""Test that unauthenticated requests to the public rooms directory 200 when """Test that unauthenticated requests to the public rooms directory 200 when
allow_public_rooms_over_federation is True. allow_public_rooms_over_federation is True.
""" """
channel = self.make_request( channel = self.make_signed_federation_request(
"GET", "GET",
"/_matrix/federation/v1/publicRooms", "/_matrix/federation/v1/publicRooms",
federation_auth_origin=b"example.com",
) )
self.assertEquals(200, channel.code) self.assertEquals(200, channel.code)

View file

@ -107,6 +107,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
return hs return hs
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor, clock, homeserver):
super().prepare(reactor, clock, homeserver)
# Create some users and a room to play with during the tests # Create some users and a room to play with during the tests
self.user_id = self.register_user("kermit", "monkey") self.user_id = self.register_user("kermit", "monkey")
self.invitee = self.register_user("invitee", "hackme") self.invitee = self.register_user("invitee", "hackme")
@ -473,8 +474,6 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def _send_event_over_federation(self) -> None: def _send_event_over_federation(self) -> None:
"""Send a dummy event over federation and check that the request succeeds.""" """Send a dummy event over federation and check that the request succeeds."""
body = { body = {
"origin": self.hs.config.server.server_name,
"origin_server_ts": self.clock.time_msec(),
"pdus": [ "pdus": [
{ {
"sender": self.user_id, "sender": self.user_id,
@ -492,11 +491,10 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
], ],
} }
channel = self.make_request( channel = self.make_signed_federation_request(
method="PUT", method="PUT",
path="/_matrix/federation/v1/send/1", path="/_matrix/federation/v1/send/1",
content=body, content=body,
federation_auth_origin=self.hs.config.server.server_name.encode("utf8"),
) )
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)

View file

@ -17,6 +17,7 @@ import gc
import hashlib import hashlib
import hmac import hmac
import inspect import inspect
import json
import logging import logging
import secrets import secrets
import time import time
@ -36,9 +37,11 @@ from typing import (
) )
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from canonicaljson import json import canonicaljson
import signedjson.key
import unpaddedbase64
from twisted.internet.defer import Deferred, ensureDeferred, succeed from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool from twisted.python.threadpool import ThreadPool
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -49,8 +52,7 @@ from twisted.web.server import Request
from synapse import events from synapse import events
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport.server import TransportLayerServer
from synapse.federation.transport import server as federation_server
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import ( from synapse.logging.context import (
@ -61,10 +63,10 @@ from synapse.logging.context import (
) )
from synapse.rest import RegisterServletsFunc from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, UserID, create_requester from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils import event_injection, setup_awaitable_errors
@ -755,42 +757,116 @@ class HomeserverTestCase(TestCase):
class FederatingHomeserverTestCase(HomeserverTestCase): class FederatingHomeserverTestCase(HomeserverTestCase):
""" """
A federating homeserver that authenticates incoming requests as `other.example.com`. A federating homeserver, set up to validate incoming federation requests
""" """
OTHER_SERVER_NAME = "other.example.com"
OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
super().prepare(reactor, clock, hs)
# poke the other server's signing key into the key store, so that we don't
# make requests for it
verify_key = signedjson.key.get_verify_key(self.OTHER_SERVER_SIGNATURE_KEY)
verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
self.get_success(
hs.get_datastore().store_server_verify_keys(
from_server=self.OTHER_SERVER_NAME,
ts_added_ms=clock.time_msec(),
verify_keys=[
(
self.OTHER_SERVER_NAME,
verify_key_id,
FetchKeyResult(
verify_key=verify_key,
valid_until_ts=clock.time_msec() + 1000,
),
)
],
)
)
def create_resource_dict(self) -> Dict[str, Resource]: def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict() d = super().create_resource_dict()
d["/_matrix/federation"] = TestTransportLayerServer(self.hs) d["/_matrix/federation"] = TransportLayerServer(self.hs)
return d return d
def make_signed_federation_request(
self,
method: str,
path: str,
content: Optional[JsonDict] = None,
await_result: bool = True,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""Make an inbound signed federation request to this server
class TestTransportLayerServer(JsonResource): The request is signed as if it came from "other.example.com", which our HS
"""A test implementation of TransportLayerServer already has the keys for.
"""
authenticates incoming requests as `other.example.com`. if custom_headers is None:
""" custom_headers = []
else:
custom_headers = list(custom_headers)
def __init__(self, hs): custom_headers.append(
super().__init__(hs) (
"Authorization",
class Authenticator: _auth_header_for_request(
def authenticate_request(self, request, content): origin=self.OTHER_SERVER_NAME,
return succeed("other.example.com") destination=self.hs.hostname,
signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
authenticator = Authenticator() method=method,
path=path,
ratelimiter = FederationRateLimiter( content=content,
hs.get_clock(), ),
FederationRateLimitConfig( )
window_size=1,
sleep_limit=1,
sleep_delay=1,
reject_limit=1000,
concurrent=1000,
),
) )
federation_server.register_servlets(hs, self, authenticator, ratelimiter) return make_request(
self.reactor,
self.site,
method=method,
path=path,
content=content,
shorthand=False,
await_result=await_result,
custom_headers=custom_headers,
client_ip=client_ip,
)
def _auth_header_for_request(
origin: str,
destination: str,
signing_key: signedjson.key.SigningKey,
method: str,
path: str,
content: Optional[JsonDict],
) -> str:
"""Build a suitable Authorization header for an outgoing federation request"""
request_description: JsonDict = {
"method": method,
"uri": path,
"destination": destination,
"origin": origin,
}
if content is not None:
request_description["content"] = content
signature_base64 = unpaddedbase64.encode_base64(
signing_key.sign(
canonicaljson.encode_canonical_json(request_description)
).signature
)
return (
f"X-Matrix origin={origin},"
f"key={signing_key.alg}:{signing_key.version},"
f"sig={signature_base64}"
)
def override_config(extra_config): def override_config(extra_config):