mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-13 21:33:20 +01:00
Faster joins: parse msc3706 fields in send_join response (#12011)
Part of my work on #11249: add code to handle the new fields added in MSC3706.
This commit is contained in:
parent
6127c4b9f1
commit
da0e9f8efd
6 changed files with 140 additions and 33 deletions
1
changelog.d/12011.misc
Normal file
1
changelog.d/12011.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Preparation for faster-room-join work: parse msc3706 fields in send_join response.
|
|
@ -64,3 +64,7 @@ class ExperimentalConfig(Config):
|
||||||
|
|
||||||
# MSC3706 (server-side support for partial state in /send_join responses)
|
# MSC3706 (server-side support for partial state in /send_join responses)
|
||||||
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)
|
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)
|
||||||
|
|
||||||
|
# experimental support for faster joins over federation (msc2775, msc3706)
|
||||||
|
# requires a target server with msc3706_enabled enabled.
|
||||||
|
self.faster_joins_enabled: bool = experimental.get("faster_joins", False)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
|
# Copyright 2015-2022 The Matrix.org Foundation C.I.C.
|
||||||
# Copyright 2020 Sorunome
|
# Copyright 2020 Sorunome
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -89,6 +89,12 @@ class SendJoinResult:
|
||||||
state: List[EventBase]
|
state: List[EventBase]
|
||||||
auth_chain: List[EventBase]
|
auth_chain: List[EventBase]
|
||||||
|
|
||||||
|
# True if 'state' elides non-critical membership events
|
||||||
|
partial_state: bool
|
||||||
|
|
||||||
|
# if 'partial_state' is set, a list of the servers in the room (otherwise empty)
|
||||||
|
servers_in_room: List[str]
|
||||||
|
|
||||||
|
|
||||||
class FederationClient(FederationBase):
|
class FederationClient(FederationBase):
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
@ -876,11 +882,18 @@ class FederationClient(FederationBase):
|
||||||
% (auth_chain_create_events,)
|
% (auth_chain_create_events,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if response.partial_state and not response.servers_in_room:
|
||||||
|
raise InvalidResponseError(
|
||||||
|
"partial_state was set, but no servers were listed in the room"
|
||||||
|
)
|
||||||
|
|
||||||
return SendJoinResult(
|
return SendJoinResult(
|
||||||
event=event,
|
event=event,
|
||||||
state=signed_state,
|
state=signed_state,
|
||||||
auth_chain=signed_auth,
|
auth_chain=signed_auth,
|
||||||
origin=destination,
|
origin=destination,
|
||||||
|
partial_state=response.partial_state,
|
||||||
|
servers_in_room=response.servers_in_room or [],
|
||||||
)
|
)
|
||||||
|
|
||||||
# MSC3083 defines additional error codes for room joins.
|
# MSC3083 defines additional error codes for room joins.
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
|
# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
|
||||||
# Copyright 2020 Sorunome
|
# Copyright 2020 Sorunome
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -60,6 +60,7 @@ class TransportLayerClient:
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.client = hs.get_federation_http_client()
|
self.client = hs.get_federation_http_client()
|
||||||
|
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
|
||||||
|
|
||||||
async def get_room_state_ids(
|
async def get_room_state_ids(
|
||||||
self, destination: str, room_id: str, event_id: str
|
self, destination: str, room_id: str, event_id: str
|
||||||
|
@ -336,10 +337,15 @@ class TransportLayerClient:
|
||||||
content: JsonDict,
|
content: JsonDict,
|
||||||
) -> "SendJoinResponse":
|
) -> "SendJoinResponse":
|
||||||
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
|
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
|
||||||
|
query_params: Dict[str, str] = {}
|
||||||
|
if self._faster_joins_enabled:
|
||||||
|
# lazy-load state on join
|
||||||
|
query_params["org.matrix.msc3706.partial_state"] = "true"
|
||||||
|
|
||||||
return await self.client.put_json(
|
return await self.client.put_json(
|
||||||
destination=destination,
|
destination=destination,
|
||||||
path=path,
|
path=path,
|
||||||
|
args=query_params,
|
||||||
data=content,
|
data=content,
|
||||||
parser=SendJoinParser(room_version, v1_api=False),
|
parser=SendJoinParser(room_version, v1_api=False),
|
||||||
max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
|
max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
|
||||||
|
@ -1271,6 +1277,12 @@ class SendJoinResponse:
|
||||||
# "event" is not included in the response.
|
# "event" is not included in the response.
|
||||||
event: Optional[EventBase] = None
|
event: Optional[EventBase] = None
|
||||||
|
|
||||||
|
# The room state is incomplete
|
||||||
|
partial_state: bool = False
|
||||||
|
|
||||||
|
# List of servers in the room
|
||||||
|
servers_in_room: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
@ijson.coroutine
|
@ijson.coroutine
|
||||||
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
|
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
|
||||||
|
@ -1297,6 +1309,32 @@ def _event_list_parser(
|
||||||
events.append(event)
|
events.append(event)
|
||||||
|
|
||||||
|
|
||||||
|
@ijson.coroutine
|
||||||
|
def _partial_state_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
|
||||||
|
"""Helper function for use with `ijson.items_coro`
|
||||||
|
|
||||||
|
Parses the partial_state field in send_join responses
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
val = yield
|
||||||
|
if not isinstance(val, bool):
|
||||||
|
raise TypeError("partial_state must be a boolean")
|
||||||
|
response.partial_state = val
|
||||||
|
|
||||||
|
|
||||||
|
@ijson.coroutine
|
||||||
|
def _servers_in_room_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
|
||||||
|
"""Helper function for use with `ijson.items_coro`
|
||||||
|
|
||||||
|
Parses the servers_in_room field in send_join responses
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
val = yield
|
||||||
|
if not isinstance(val, list) or any(not isinstance(x, str) for x in val):
|
||||||
|
raise TypeError("servers_in_room must be a list of strings")
|
||||||
|
response.servers_in_room = val
|
||||||
|
|
||||||
|
|
||||||
class SendJoinParser(ByteParser[SendJoinResponse]):
|
class SendJoinParser(ByteParser[SendJoinResponse]):
|
||||||
"""A parser for the response to `/send_join` requests.
|
"""A parser for the response to `/send_join` requests.
|
||||||
|
|
||||||
|
@ -1308,44 +1346,62 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
|
||||||
CONTENT_TYPE = "application/json"
|
CONTENT_TYPE = "application/json"
|
||||||
|
|
||||||
def __init__(self, room_version: RoomVersion, v1_api: bool):
|
def __init__(self, room_version: RoomVersion, v1_api: bool):
|
||||||
self._response = SendJoinResponse([], [], {})
|
self._response = SendJoinResponse([], [], event_dict={})
|
||||||
self._room_version = room_version
|
self._room_version = room_version
|
||||||
|
self._coros = []
|
||||||
|
|
||||||
# The V1 API has the shape of `[200, {...}]`, which we handle by
|
# The V1 API has the shape of `[200, {...}]`, which we handle by
|
||||||
# prefixing with `item.*`.
|
# prefixing with `item.*`.
|
||||||
prefix = "item." if v1_api else ""
|
prefix = "item." if v1_api else ""
|
||||||
|
|
||||||
self._coro_state = ijson.items_coro(
|
self._coros = [
|
||||||
_event_list_parser(room_version, self._response.state),
|
ijson.items_coro(
|
||||||
prefix + "state.item",
|
_event_list_parser(room_version, self._response.state),
|
||||||
use_float=True,
|
prefix + "state.item",
|
||||||
)
|
use_float=True,
|
||||||
self._coro_auth = ijson.items_coro(
|
),
|
||||||
_event_list_parser(room_version, self._response.auth_events),
|
ijson.items_coro(
|
||||||
prefix + "auth_chain.item",
|
_event_list_parser(room_version, self._response.auth_events),
|
||||||
use_float=True,
|
prefix + "auth_chain.item",
|
||||||
)
|
use_float=True,
|
||||||
# TODO Remove the unstable prefix when servers have updated.
|
),
|
||||||
#
|
# TODO Remove the unstable prefix when servers have updated.
|
||||||
# By re-using the same event dictionary this will cause the parsing of
|
#
|
||||||
# org.matrix.msc3083.v2.event and event to stomp over each other.
|
# By re-using the same event dictionary this will cause the parsing of
|
||||||
# Generally this should be fine.
|
# org.matrix.msc3083.v2.event and event to stomp over each other.
|
||||||
self._coro_unstable_event = ijson.kvitems_coro(
|
# Generally this should be fine.
|
||||||
_event_parser(self._response.event_dict),
|
ijson.kvitems_coro(
|
||||||
prefix + "org.matrix.msc3083.v2.event",
|
_event_parser(self._response.event_dict),
|
||||||
use_float=True,
|
prefix + "org.matrix.msc3083.v2.event",
|
||||||
)
|
use_float=True,
|
||||||
self._coro_event = ijson.kvitems_coro(
|
),
|
||||||
_event_parser(self._response.event_dict),
|
ijson.kvitems_coro(
|
||||||
prefix + "event",
|
_event_parser(self._response.event_dict),
|
||||||
use_float=True,
|
prefix + "event",
|
||||||
)
|
use_float=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
if not v1_api:
|
||||||
|
self._coros.append(
|
||||||
|
ijson.items_coro(
|
||||||
|
_partial_state_parser(self._response),
|
||||||
|
"org.matrix.msc3706.partial_state",
|
||||||
|
use_float="True",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._coros.append(
|
||||||
|
ijson.items_coro(
|
||||||
|
_servers_in_room_parser(self._response),
|
||||||
|
"org.matrix.msc3706.servers_in_room",
|
||||||
|
use_float="True",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def write(self, data: bytes) -> int:
|
def write(self, data: bytes) -> int:
|
||||||
self._coro_state.send(data)
|
for c in self._coros:
|
||||||
self._coro_auth.send(data)
|
c.send(data)
|
||||||
self._coro_unstable_event.send(data)
|
|
||||||
self._coro_event.send(data)
|
|
||||||
|
|
||||||
return len(data)
|
return len(data)
|
||||||
|
|
||||||
|
|
|
@ -87,7 +87,8 @@ REQUIREMENTS = [
|
||||||
# We enforce that we have a `cryptography` version that bundles an `openssl`
|
# We enforce that we have a `cryptography` version that bundles an `openssl`
|
||||||
# with the latest security patches.
|
# with the latest security patches.
|
||||||
"cryptography>=3.4.7",
|
"cryptography>=3.4.7",
|
||||||
"ijson>=3.1",
|
# ijson 3.1.4 fixes a bug with "." in property names
|
||||||
|
"ijson>=3.1.4",
|
||||||
"matrix-common~=1.1.0",
|
"matrix-common~=1.1.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -62,3 +62,35 @@ class SendJoinParserTestCase(TestCase):
|
||||||
self.assertEqual(len(parsed_response.state), 1, parsed_response)
|
self.assertEqual(len(parsed_response.state), 1, parsed_response)
|
||||||
self.assertEqual(parsed_response.event_dict, {}, parsed_response)
|
self.assertEqual(parsed_response.event_dict, {}, parsed_response)
|
||||||
self.assertIsNone(parsed_response.event, parsed_response)
|
self.assertIsNone(parsed_response.event, parsed_response)
|
||||||
|
self.assertFalse(parsed_response.partial_state, parsed_response)
|
||||||
|
self.assertEqual(parsed_response.servers_in_room, None, parsed_response)
|
||||||
|
|
||||||
|
def test_partial_state(self) -> None:
|
||||||
|
"""Check that the partial_state flag is correctly parsed"""
|
||||||
|
parser = SendJoinParser(RoomVersions.V1, False)
|
||||||
|
response = {
|
||||||
|
"org.matrix.msc3706.partial_state": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
serialised_response = json.dumps(response).encode()
|
||||||
|
|
||||||
|
# Send data to the parser
|
||||||
|
parser.write(serialised_response)
|
||||||
|
|
||||||
|
# Retrieve and check the parsed SendJoinResponse
|
||||||
|
parsed_response = parser.finish()
|
||||||
|
self.assertTrue(parsed_response.partial_state)
|
||||||
|
|
||||||
|
def test_servers_in_room(self) -> None:
|
||||||
|
"""Check that the servers_in_room field is correctly parsed"""
|
||||||
|
parser = SendJoinParser(RoomVersions.V1, False)
|
||||||
|
response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}
|
||||||
|
|
||||||
|
serialised_response = json.dumps(response).encode()
|
||||||
|
|
||||||
|
# Send data to the parser
|
||||||
|
parser.write(serialised_response)
|
||||||
|
|
||||||
|
# Retrieve and check the parsed SendJoinResponse
|
||||||
|
parsed_response = parser.finish()
|
||||||
|
self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"])
|
||||||
|
|
Loading…
Reference in a new issue