0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-19 03:41:54 +01:00

Consistently use room_id from federation request body (#8776)

* Consistently use room_id from federation request body

Some federation APIs have a redundant `room_id` path param (see
https://github.com/matrix-org/matrix-doc/issues/2330). We should make sure we
consistently use either the path param or the body param, and the body param is
easier.

* Kill off some references to "context"

Once upon a time, "rooms" were known as "contexts". I think this kills of the
last references to "contexts".
This commit is contained in:
Richard van der Hoff 2020-11-19 10:05:33 +00:00 committed by Erik Johnston
parent 244bff4edd
commit 3ce2f303f1
5 changed files with 49 additions and 54 deletions

1
changelog.d/8776.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug in some federation APIs which could lead to unexpected behaviour if different parameters were set in the URI and the request body.

View file

@ -49,6 +49,7 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js
from synapse.federation.persistence import TransactionActions from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name from synapse.http.endpoint import parse_server_name
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import ( from synapse.logging.context import (
make_deferred_yieldable, make_deferred_yieldable,
nested_logging_context, nested_logging_context,
@ -391,7 +392,7 @@ class FederationServer(FederationBase):
TRANSACTION_CONCURRENCY_LIMIT, TRANSACTION_CONCURRENCY_LIMIT,
) )
async def on_context_state_request( async def on_room_state_request(
self, origin: str, room_id: str, event_id: str self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]: ) -> Tuple[int, Dict[str, Any]]:
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
@ -514,11 +515,12 @@ class FederationServer(FederationBase):
return {"event": ret_pdu.get_pdu_json(time_now)} return {"event": ret_pdu.get_pdu_json(time_now)}
async def on_send_join_request( async def on_send_join_request(
self, origin: str, content: JsonDict, room_id: str self, origin: str, content: JsonDict
) -> Dict[str, Any]: ) -> Dict[str, Any]:
logger.debug("on_send_join_request: content: %s", content) logger.debug("on_send_join_request: content: %s", content)
room_version = await self.store.get_room_version(room_id) assert_params_in_dict(content, ["room_id"])
room_version = await self.store.get_room_version(content["room_id"])
pdu = event_from_pdu_json(content, room_version) pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
@ -547,12 +549,11 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
async def on_send_leave_request( async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict:
self, origin: str, content: JsonDict, room_id: str
) -> dict:
logger.debug("on_send_leave_request: content: %s", content) logger.debug("on_send_leave_request: content: %s", content)
room_version = await self.store.get_room_version(room_id) assert_params_in_dict(content, ["room_id"])
room_version = await self.store.get_room_version(content["room_id"])
pdu = event_from_pdu_json(content, room_version) pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
@ -748,12 +749,8 @@ class FederationServer(FederationBase):
) )
return ret return ret
async def on_exchange_third_party_invite_request( async def on_exchange_third_party_invite_request(self, event_dict: Dict):
self, room_id: str, event_dict: Dict ret = await self.handler.on_exchange_third_party_invite_request(event_dict)
):
ret = await self.handler.on_exchange_third_party_invite_request(
room_id, event_dict
)
return ret return ret
async def check_server_matches_acl(self, server_name: str, room_id: str): async def check_server_matches_acl(self, server_name: str, room_id: str):

View file

@ -440,13 +440,13 @@ class FederationEventServlet(BaseFederationServlet):
class FederationStateV1Servlet(BaseFederationServlet): class FederationStateV1Servlet(BaseFederationServlet):
PATH = "/state/(?P<context>[^/]*)/?" PATH = "/state/(?P<room_id>[^/]*)/?"
# This is when someone asks for all data for a given context. # This is when someone asks for all data for a given room.
async def on_GET(self, origin, content, query, context): async def on_GET(self, origin, content, query, room_id):
return await self.handler.on_context_state_request( return await self.handler.on_room_state_request(
origin, origin,
context, room_id,
parse_string_from_args(query, "event_id", None, required=False), parse_string_from_args(query, "event_id", None, required=False),
) )
@ -463,16 +463,16 @@ class FederationStateIdsServlet(BaseFederationServlet):
class FederationBackfillServlet(BaseFederationServlet): class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/(?P<context>[^/]*)/?" PATH = "/backfill/(?P<room_id>[^/]*)/?"
async def on_GET(self, origin, content, query, context): async def on_GET(self, origin, content, query, room_id):
versions = [x.decode("ascii") for x in query[b"v"]] versions = [x.decode("ascii") for x in query[b"v"]]
limit = parse_integer_from_args(query, "limit", None) limit = parse_integer_from_args(query, "limit", None)
if not limit: if not limit:
return 400, {"error": "Did not include limit param"} return 400, {"error": "Did not include limit param"}
return await self.handler.on_backfill_request(origin, context, versions, limit) return await self.handler.on_backfill_request(origin, room_id, versions, limit)
class FederationQueryServlet(BaseFederationServlet): class FederationQueryServlet(BaseFederationServlet):
@ -487,9 +487,9 @@ class FederationQueryServlet(BaseFederationServlet):
class FederationMakeJoinServlet(BaseFederationServlet): class FederationMakeJoinServlet(BaseFederationServlet):
PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)" PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
async def on_GET(self, origin, _content, query, context, user_id): async def on_GET(self, origin, _content, query, room_id, user_id):
""" """
Args: Args:
origin (unicode): The authenticated server_name of the calling server origin (unicode): The authenticated server_name of the calling server
@ -511,16 +511,16 @@ class FederationMakeJoinServlet(BaseFederationServlet):
supported_versions = ["1"] supported_versions = ["1"]
content = await self.handler.on_make_join_request( content = await self.handler.on_make_join_request(
origin, context, user_id, supported_versions=supported_versions origin, room_id, user_id, supported_versions=supported_versions
) )
return 200, content return 200, content
class FederationMakeLeaveServlet(BaseFederationServlet): class FederationMakeLeaveServlet(BaseFederationServlet):
PATH = "/make_leave/(?P<context>[^/]*)/(?P<user_id>[^/]*)" PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
async def on_GET(self, origin, content, query, context, user_id): async def on_GET(self, origin, content, query, room_id, user_id):
content = await self.handler.on_make_leave_request(origin, context, user_id) content = await self.handler.on_make_leave_request(origin, room_id, user_id)
return 200, content return 200, content
@ -528,7 +528,7 @@ class FederationV1SendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id): async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_leave_request(origin, content, room_id) content = await self.handler.on_send_leave_request(origin, content)
return 200, (200, content) return 200, (200, content)
@ -538,43 +538,43 @@ class FederationV2SendLeaveServlet(BaseFederationServlet):
PREFIX = FEDERATION_V2_PREFIX PREFIX = FEDERATION_V2_PREFIX
async def on_PUT(self, origin, content, query, room_id, event_id): async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_leave_request(origin, content, room_id) content = await self.handler.on_send_leave_request(origin, content)
return 200, content return 200, content
class FederationEventAuthServlet(BaseFederationServlet): class FederationEventAuthServlet(BaseFederationServlet):
PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)" PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_GET(self, origin, content, query, context, event_id): async def on_GET(self, origin, content, query, room_id, event_id):
return await self.handler.on_event_auth(origin, context, event_id) return await self.handler.on_event_auth(origin, room_id, event_id)
class FederationV1SendJoinServlet(BaseFederationServlet): class FederationV1SendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, context, event_id): async def on_PUT(self, origin, content, query, room_id, event_id):
# TODO(paul): assert that context/event_id parsed from path actually # TODO(paul): assert that room_id/event_id parsed from path actually
# match those given in content # match those given in content
content = await self.handler.on_send_join_request(origin, content, context) content = await self.handler.on_send_join_request(origin, content)
return 200, (200, content) return 200, (200, content)
class FederationV2SendJoinServlet(BaseFederationServlet): class FederationV2SendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX PREFIX = FEDERATION_V2_PREFIX
async def on_PUT(self, origin, content, query, context, event_id): async def on_PUT(self, origin, content, query, room_id, event_id):
# TODO(paul): assert that context/event_id parsed from path actually # TODO(paul): assert that room_id/event_id parsed from path actually
# match those given in content # match those given in content
content = await self.handler.on_send_join_request(origin, content, context) content = await self.handler.on_send_join_request(origin, content)
return 200, content return 200, content
class FederationV1InviteServlet(BaseFederationServlet): class FederationV1InviteServlet(BaseFederationServlet):
PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)" PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, context, event_id): async def on_PUT(self, origin, content, query, room_id, event_id):
# We don't get a room version, so we have to assume its EITHER v1 or # We don't get a room version, so we have to assume its EITHER v1 or
# v2. This is "fine" as the only difference between V1 and V2 is the # v2. This is "fine" as the only difference between V1 and V2 is the
# state resolution algorithm, and we don't use that for processing # state resolution algorithm, and we don't use that for processing
@ -589,12 +589,12 @@ class FederationV1InviteServlet(BaseFederationServlet):
class FederationV2InviteServlet(BaseFederationServlet): class FederationV2InviteServlet(BaseFederationServlet):
PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)" PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX PREFIX = FEDERATION_V2_PREFIX
async def on_PUT(self, origin, content, query, context, event_id): async def on_PUT(self, origin, content, query, room_id, event_id):
# TODO(paul): assert that context/event_id parsed from path actually # TODO(paul): assert that room_id/event_id parsed from path actually
# match those given in content # match those given in content
room_version = content["room_version"] room_version = content["room_version"]
@ -616,9 +616,7 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)" PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id): async def on_PUT(self, origin, content, query, room_id):
content = await self.handler.on_exchange_third_party_invite_request( content = await self.handler.on_exchange_third_party_invite_request(content)
room_id, content
)
return 200, content return 200, content

View file

@ -55,6 +55,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.handlers._base import BaseHandler from synapse.handlers._base import BaseHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import ( from synapse.logging.context import (
make_deferred_yieldable, make_deferred_yieldable,
nested_logging_context, nested_logging_context,
@ -2686,7 +2687,7 @@ class FederationHandler(BaseHandler):
) )
async def on_exchange_third_party_invite_request( async def on_exchange_third_party_invite_request(
self, room_id: str, event_dict: JsonDict self, event_dict: JsonDict
) -> None: ) -> None:
"""Handle an exchange_third_party_invite request from a remote server """Handle an exchange_third_party_invite request from a remote server
@ -2694,12 +2695,11 @@ class FederationHandler(BaseHandler):
into a normal m.room.member invite. into a normal m.room.member invite.
Args: Args:
room_id: The ID of the room. event_dict: Dictionary containing the event body.
event_dict (dict[str, Any]): Dictionary containing the event body.
""" """
room_version = await self.store.get_room_version_id(room_id) assert_params_in_dict(event_dict, ["room_id"])
room_version = await self.store.get_room_version_id(event_dict["room_id"])
# NB: event_dict has a particular specced format we might need to fudge # NB: event_dict has a particular specced format we might need to fudge
# if we change event formats too much. # if we change event formats too much.

View file

@ -59,7 +59,6 @@ class FederationTestCase(unittest.HomeserverTestCase):
) )
d = self.handler.on_exchange_third_party_invite_request( d = self.handler.on_exchange_third_party_invite_request(
room_id=room_id,
event_dict={ event_dict={
"type": EventTypes.Member, "type": EventTypes.Member,
"room_id": room_id, "room_id": room_id,