Use ijson to parse the response to /send_join, reducing memory usage. (#9958)

Instead of parsing the full response to `/send_join` into Python objects (which can be huge for large rooms) and *then* parsing those objects into events, we now use ijson to stream-parse the response directly into `EventBase` objects.
This commit is contained in:
Erik Johnston 2021-05-20 16:11:48 +01:00 committed by GitHub
parent 551d2c3f4b
commit 64887f06fc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 228 additions and 59 deletions

1
changelog.d/9958.feature Normal file
View file

@ -0,0 +1 @@
Reduce memory usage when joining very large rooms over federation.

View file

@ -174,3 +174,6 @@ ignore_missing_imports = True
[mypy-pympler.*] [mypy-pympler.*]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-ijson.*]
ignore_missing_imports = True

View file

@ -55,6 +55,7 @@ from synapse.api.room_versions import (
) )
from synapse.events import EventBase, builder from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.transport.client import SendJoinResponse
from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.types import JsonDict, get_domain_from_id from synapse.types import JsonDict, get_domain_from_id
@ -665,19 +666,10 @@ class FederationClient(FederationBase):
""" """
async def send_request(destination) -> Dict[str, Any]: async def send_request(destination) -> Dict[str, Any]:
content = await self._do_send_join(destination, pdu) response = await self._do_send_join(room_version, destination, pdu)
logger.debug("Got content: %s", content) state = response.state
auth_chain = response.auth_events
state = [
event_from_pdu_json(p, room_version, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
event_from_pdu_json(p, room_version, outlier=True)
for p in content.get("auth_chain", [])
]
pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)} pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
@ -752,11 +744,14 @@ class FederationClient(FederationBase):
return await self._try_destination_list("send_join", destinations, send_request) return await self._try_destination_list("send_join", destinations, send_request)
async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict: async def _do_send_join(
self, room_version: RoomVersion, destination: str, pdu: EventBase
) -> SendJoinResponse:
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
try: try:
return await self.transport_layer.send_join_v2( return await self.transport_layer.send_join_v2(
room_version=room_version,
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
@ -771,17 +766,14 @@ class FederationClient(FederationBase):
logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API") logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
resp = await self.transport_layer.send_join_v1( return await self.transport_layer.send_join_v1(
room_version=room_version,
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now), content=pdu.get_pdu_json(time_now),
) )
# We expect the v1 API to respond with [200, content], so we only return the
# content.
return resp[1]
async def send_invite( async def send_invite(
self, self,
destination: str, destination: str,

View file

@ -17,13 +17,19 @@ import logging
import urllib import urllib
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import attr
import ijson
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.api.urls import ( from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX, FEDERATION_UNSTABLE_PREFIX,
FEDERATION_V1_PREFIX, FEDERATION_V1_PREFIX,
FEDERATION_V2_PREFIX, FEDERATION_V2_PREFIX,
) )
from synapse.events import EventBase, make_event_from_dict
from synapse.http.matrixfederationclient import ByteParser
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.types import JsonDict from synapse.types import JsonDict
@ -240,21 +246,36 @@ class TransportLayerClient:
return content return content
@log_function @log_function
async def send_join_v1(self, destination, room_id, event_id, content): async def send_join_v1(
self,
room_version,
destination,
room_id,
event_id,
content,
) -> "SendJoinResponse":
path = _create_v1_path("/send_join/%s/%s", room_id, event_id) path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = await self.client.put_json( response = await self.client.put_json(
destination=destination, path=path, data=content destination=destination,
path=path,
data=content,
parser=SendJoinParser(room_version, v1_api=True),
) )
return response return response
@log_function @log_function
async def send_join_v2(self, destination, room_id, event_id, content): async def send_join_v2(
self, room_version, destination, room_id, event_id, content
) -> "SendJoinResponse":
path = _create_v2_path("/send_join/%s/%s", room_id, event_id) path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
response = await self.client.put_json( response = await self.client.put_json(
destination=destination, path=path, data=content destination=destination,
path=path,
data=content,
parser=SendJoinParser(room_version, v1_api=False),
) )
return response return response
@ -1053,3 +1074,59 @@ def _create_v2_path(path, *args):
str str
""" """
return _create_path(FEDERATION_V2_PREFIX, path, *args) return _create_path(FEDERATION_V2_PREFIX, path, *args)
@attr.s(slots=True, auto_attribs=True)
class SendJoinResponse:
    """The parsed response of a `/send_join` request."""

    # Events streamed out of the response's `auth_chain` array.
    auth_events: List[EventBase]
    # Events streamed out of the response's `state` array.
    state: List[EventBase]
@ijson.coroutine
def _event_list_parser(room_version: RoomVersion, events: List[EventBase]):
    """Sink coroutine for `ijson.items_coro`: turns each JSON object it is sent
    into an event and appends it to `events`.

    Args:
        room_version: The room version used to build each event.
        events: The list that parsed events are appended to (mutated in place).
    """
    # Runs until the driving ijson coroutine stops sending items or closes us.
    while True:
        pdu_json = yield
        events.append(make_event_from_dict(pdu_json, room_version))
class SendJoinParser(ByteParser[SendJoinResponse]):
    """A parser for the response to `/send_join` requests.

    The body is fed chunk by chunk into two ijson coroutines, one extracting
    the `state` array and one extracting the `auth_chain` array, so the full
    response never needs to be held as Python objects.

    Args:
        room_version: The version of the room.
        v1_api: Whether the response is in the v1 format.
    """

    CONTENT_TYPE = "application/json"

    def __init__(self, room_version: RoomVersion, v1_api: bool):
        self._response = SendJoinResponse([], [])

        # The V1 API has the shape of `[200, {...}]`, which we handle by
        # prefixing with `item.*`.
        prefix = "item." if v1_api else ""

        self._coro_state = ijson.items_coro(
            _event_list_parser(room_version, self._response.state),
            prefix + "state.item",
        )
        self._coro_auth = ijson.items_coro(
            _event_list_parser(room_version, self._response.auth_events),
            prefix + "auth_chain.item",
        )

    def write(self, data: bytes) -> int:
        """Feed a chunk of the response body to both event extractors.

        Returns:
            The number of bytes consumed (always all of `data`).
        """
        self._coro_state.send(data)
        self._coro_auth.send(data)

        return len(data)

    def finish(self) -> SendJoinResponse:
        """Signal end-of-input to the parsers and return the parsed response.

        Raises:
            ValueError: if the body ended before the JSON document was
                complete (or was otherwise rejected by ijson at end of input).
        """
        # ijson only learns that the input has ended when the coroutine is
        # closed; without this a truncated body would be silently accepted
        # and a partial event list returned. Close both before raising so
        # neither is left dangling.
        first_error = None
        for coro in (self._coro_state, self._coro_auth):
            try:
                coro.close()
            except ijson.JSONError as e:
                if first_error is None:
                    first_error = e

        if first_error is not None:
            # Surface as ValueError so callers treat it like any other
            # unparseable response body.
            raise ValueError(
                "Incomplete or invalid JSON in /send_join response"
            ) from first_error

        return self._response

View file

@ -813,7 +813,12 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
if self.deferred.called: if self.deferred.called:
return return
self.stream.write(data) try:
self.stream.write(data)
except Exception:
self.deferred.errback()
return
self.length += len(data) self.length += len(data)
# The first time the maximum size is exceeded, error and cancel the # The first time the maximum size is exceeded, error and cancel the
# connection. dataReceived might be called again if data was received # connection. dataReceived might be called again if data was received

View file

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import cgi import cgi
import codecs import codecs
import logging import logging
@ -19,13 +20,24 @@ import sys
import typing import typing
import urllib.parse import urllib.parse
from io import BytesIO, StringIO from io import BytesIO, StringIO
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import (
Callable,
Dict,
Generic,
List,
Optional,
Tuple,
TypeVar,
Union,
overload,
)
import attr import attr
import treq import treq
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from prometheus_client import Counter from prometheus_client import Counter
from signedjson.sign import sign_json from signedjson.sign import sign_json
from typing_extensions import Literal
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
@ -48,6 +60,7 @@ from synapse.http.client import (
BlacklistingAgentWrapper, BlacklistingAgentWrapper,
BlacklistingReactorWrapper, BlacklistingReactorWrapper,
BodyExceededMaxSize, BodyExceededMaxSize,
ByteWriteable,
encode_query_args, encode_query_args,
read_body_with_max_size, read_body_with_max_size,
) )
@ -88,6 +101,27 @@ _next_id = 1
QueryArgs = Dict[str, Union[str, List[str]]] QueryArgs = Dict[str, Union[str, List[str]]]
T = TypeVar("T")
class ByteParser(ByteWriteable, Generic[T], abc.ABC):
    """A `ByteWriteable` sink that can also be asked, via `finish`, for the
    value parsed from everything written to it.
    """

    CONTENT_TYPE = abc.abstractproperty()  # type: str  # type: ignore
    """The expected content type of the response, e.g. `application/json`. If
    the content type doesn't match we fail the request.
    """

    @abc.abstractmethod
    def finish(self) -> T:
        """Invoked once the response has finished streaming; implementations
        return the final parsed result (or raise an error).
        """
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class MatrixFederationRequest: class MatrixFederationRequest:
method = attr.ib(type=str) method = attr.ib(type=str)
@ -148,15 +182,32 @@ class MatrixFederationRequest:
return self.json return self.json
async def _handle_json_response( class JsonParser(ByteParser[Union[JsonDict, list]]):
"""A parser that buffers the response and tries to parse it as JSON."""
CONTENT_TYPE = "application/json"
def __init__(self):
self._buffer = StringIO()
self._binary_wrapper = BinaryIOWrapper(self._buffer)
def write(self, data: bytes) -> int:
return self._binary_wrapper.write(data)
def finish(self) -> Union[JsonDict, list]:
return json_decoder.decode(self._buffer.getvalue())
async def _handle_response(
reactor: IReactorTime, reactor: IReactorTime,
timeout_sec: float, timeout_sec: float,
request: MatrixFederationRequest, request: MatrixFederationRequest,
response: IResponse, response: IResponse,
start_ms: int, start_ms: int,
) -> JsonDict: parser: ByteParser[T],
) -> T:
""" """
Reads the JSON body of a response, with a timeout Reads the body of a response with a timeout and sends it to a parser
Args: Args:
reactor: twisted reactor, for the timeout reactor: twisted reactor, for the timeout
@ -164,23 +215,21 @@ async def _handle_json_response(
request: the request that triggered the response request: the request that triggered the response
response: response to the request response: response to the request
start_ms: Timestamp when request was made start_ms: Timestamp when request was made
parser: The parser for the response
Returns: Returns:
The parsed JSON response The parsed response
""" """
try:
check_content_type_is_json(response.headers)
buf = StringIO() try:
d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE) check_content_type_is(response.headers, parser.CONTENT_TYPE)
d = read_body_with_max_size(response, parser, MAX_RESPONSE_SIZE)
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
def parse(_len: int): length = await make_deferred_yieldable(d)
return json_decoder.decode(buf.getvalue())
d.addCallback(parse) value = parser.finish()
body = await make_deferred_yieldable(d)
except BodyExceededMaxSize as e: except BodyExceededMaxSize as e:
# The response was too big. # The response was too big.
logger.warning( logger.warning(
@ -193,9 +242,9 @@ async def _handle_json_response(
) )
raise RequestSendFailed(e, can_retry=False) from e raise RequestSendFailed(e, can_retry=False) from e
except ValueError as e: except ValueError as e:
# The JSON content was invalid. # The content was invalid.
logger.warning( logger.warning(
"{%s} [%s] Failed to parse JSON response - %s %s", "{%s} [%s] Failed to parse response - %s %s",
request.txn_id, request.txn_id,
request.destination, request.destination,
request.method, request.method,
@ -225,16 +274,17 @@ async def _handle_json_response(
time_taken_secs = reactor.seconds() - start_ms / 1000 time_taken_secs = reactor.seconds() - start_ms / 1000
logger.info( logger.info(
"{%s} [%s] Completed request: %d %s in %.2f secs - %s %s", "{%s} [%s] Completed request: %d %s in %.2f secs, got %d bytes - %s %s",
request.txn_id, request.txn_id,
request.destination, request.destination,
response.code, response.code,
response.phrase.decode("ascii", errors="replace"), response.phrase.decode("ascii", errors="replace"),
time_taken_secs, time_taken_secs,
length,
request.method, request.method,
request.uri.decode("ascii"), request.uri.decode("ascii"),
) )
return body return value
class BinaryIOWrapper: class BinaryIOWrapper:
@ -671,6 +721,7 @@ class MatrixFederationHttpClient:
) )
return auth_headers return auth_headers
@overload
async def put_json( async def put_json(
self, self,
destination: str, destination: str,
@ -683,7 +734,41 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False, ignore_backoff: bool = False,
backoff_on_404: bool = False, backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False, try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
) -> Union[JsonDict, list]: ) -> Union[JsonDict, list]:
...
@overload
async def put_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
data: Optional[JsonDict] = None,
json_data_callback: Optional[Callable[[], JsonDict]] = None,
long_retries: bool = False,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser[T]] = None,
) -> T:
...
async def put_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
data: Optional[JsonDict] = None,
json_data_callback: Optional[Callable[[], JsonDict]] = None,
long_retries: bool = False,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser] = None,
):
"""Sends the specified json data using PUT """Sends the specified json data using PUT
Args: Args:
@ -716,6 +801,8 @@ class MatrixFederationHttpClient:
of the request. Workaround for #3622 in Synapse <= v0.99.3. This of the request. Workaround for #3622 in Synapse <= v0.99.3. This
will be attempted before backing off if backing off has been will be attempted before backing off if backing off has been
enabled. enabled.
parser: The parser to use to decode the response. Defaults to
parsing as JSON.
Returns: Returns:
Succeeds when we get a 2xx HTTP response. The Succeeds when we get a 2xx HTTP response. The
@ -756,8 +843,16 @@ class MatrixFederationHttpClient:
else: else:
_sec_timeout = self.default_timeout _sec_timeout = self.default_timeout
body = await _handle_json_response( if parser is None:
self.reactor, _sec_timeout, request, response, start_ms parser = JsonParser()
body = await _handle_response(
self.reactor,
_sec_timeout,
request,
response,
start_ms,
parser=parser,
) )
return body return body
@ -830,12 +925,8 @@ class MatrixFederationHttpClient:
else: else:
_sec_timeout = self.default_timeout _sec_timeout = self.default_timeout
body = await _handle_json_response( body = await _handle_response(
self.reactor, self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
_sec_timeout,
request,
response,
start_ms,
) )
return body return body
@ -907,8 +998,8 @@ class MatrixFederationHttpClient:
else: else:
_sec_timeout = self.default_timeout _sec_timeout = self.default_timeout
body = await _handle_json_response( body = await _handle_response(
self.reactor, _sec_timeout, request, response, start_ms self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
) )
return body return body
@ -975,8 +1066,8 @@ class MatrixFederationHttpClient:
else: else:
_sec_timeout = self.default_timeout _sec_timeout = self.default_timeout
body = await _handle_json_response( body = await _handle_response(
self.reactor, _sec_timeout, request, response, start_ms self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
) )
return body return body
@ -1068,16 +1159,16 @@ def _flatten_response_never_received(e):
return repr(e) return repr(e)
def check_content_type_is_json(headers: Headers) -> None: def check_content_type_is(headers: Headers, expected_content_type: str) -> None:
""" """
Check that a set of HTTP headers have a Content-Type header, and that it Check that a set of HTTP headers have a Content-Type header, and that it
is application/json. is the expected value..
Args: Args:
headers: headers to check headers: headers to check
Raises: Raises:
RequestSendFailed: if the Content-Type header is missing or isn't JSON RequestSendFailed: if the Content-Type header is missing or doesn't match
""" """
content_type_headers = headers.getRawHeaders(b"Content-Type") content_type_headers = headers.getRawHeaders(b"Content-Type")
@ -1089,11 +1180,10 @@ def check_content_type_is_json(headers: Headers) -> None:
c_type = content_type_headers[0].decode("ascii") # only the first header c_type = content_type_headers[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type) val, options = cgi.parse_header(c_type)
if val != "application/json": if val != expected_content_type:
raise RequestSendFailed( raise RequestSendFailed(
RuntimeError( RuntimeError(
"Remote server sent Content-Type header of '%s', not 'application/json'" f"Remote server sent Content-Type header of '{c_type}', not '{expected_content_type}'",
% c_type,
), ),
can_retry=False, can_retry=False,
) )

View file

@ -87,6 +87,7 @@ REQUIREMENTS = [
# We enforce that we have a `cryptography` version that bundles an `openssl` # We enforce that we have a `cryptography` version that bundles an `openssl`
# with the latest security patches. # with the latest security patches.
"cryptography>=3.4.7", "cryptography>=3.4.7",
"ijson>=3.0",
] ]
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {