Respect the @cancellable flag for RestServlets and BaseFederationServlets (#12699)

Both `RestServlet`s and `BaseFederationServlet`s register their handlers
with `HttpServer.register_paths` / `JsonResource.register_paths`. Update
`JsonResource` to respect the `@cancellable` flag on handlers registered
in this way.

Although `ReplicationEndpoint` also registers itself using
`register_paths`, it does not pass the handler method that would have the
`@cancellable` flag directly, and so needs separate handling.

Signed-off-by: Sean Quah <seanq@element.io>
This commit is contained in:
Sean Quah 2022-05-11 12:25:13 +01:00 committed by GitHub
parent dffecade7d
commit 9d8e380d2e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 191 additions and 2 deletions

1
changelog.d/12699.misc Normal file
View file

@ -0,0 +1 @@
Respect the `@cancellable` flag for `RestServlet`s and `BaseFederationServlet`s.

View file

@ -314,6 +314,9 @@ class HttpServer(Protocol):
If the regex contains groups these gets passed to the callback via If the regex contains groups these gets passed to the callback via
an unpacked tuple. an unpacked tuple.
The callback may be marked with the `@cancellable` decorator, which will
cause request processing to be cancelled when clients disconnect early.
Args: Args:
method: The HTTP method to listen to. method: The HTTP method to listen to.
path_patterns: The regex used to match requests. path_patterns: The regex used to match requests.
@ -544,6 +547,8 @@ class JsonResource(DirectServeJsonResource):
async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]: async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request) callback, servlet_classname, group_dict = self._get_handler_for_request(request)
request.is_render_cancellable = is_method_cancellable(callback)
# Make sure we have an appropriate name for this handler in prometheus # Make sure we have an appropriate name for this handler in prometheus
# (rather than the default of JsonResource). # (rather than the default of JsonResource).
request.request_metrics.name = servlet_classname request.request_metrics.name = servlet_classname

View file

@ -0,0 +1,13 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,112 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from typing import Dict, List, Tuple
from synapse.api.errors import Codes
from synapse.federation.transport.server import BaseFederationServlet
from synapse.federation.transport.server._base import Authenticator
from synapse.http.server import JsonResource, cancellable
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin
class CancellableFederationServlet(BaseFederationServlet):
PATH = "/sleep"
def __init__(
self,
hs: HomeServer,
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self.clock = hs.get_clock()
@cancellable
async def on_GET(
self, origin: str, content: None, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
await self.clock.sleep(1.0)
return HTTPStatus.OK, {"result": True}
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
await self.clock.sleep(1.0)
return HTTPStatus.OK, {"result": True}
class BaseFederationServletCancellationTests(
unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin
):
"""Tests for `BaseFederationServlet` cancellation."""
path = f"{CancellableFederationServlet.PREFIX}{CancellableFederationServlet.PATH}"
def create_test_resource(self):
"""Overrides `HomeserverTestCase.create_test_resource`."""
resource = JsonResource(self.hs)
CancellableFederationServlet(
hs=self.hs,
authenticator=Authenticator(self.hs),
ratelimiter=self.hs.get_federation_ratelimiter(),
server_name=self.hs.hostname,
).register(resource)
return resource
def test_cancellable_disconnect(self) -> None:
"""Test that handlers with the `@cancellable` flag can be cancelled."""
channel = self.make_signed_federation_request(
"GET", self.path, await_result=False
)
# Advance past all the rate limiting logic. If we disconnect too early, the
# request won't be processed.
self.pump()
self._test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
)
def test_uncancellable_disconnect(self) -> None:
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
channel = self.make_signed_federation_request(
"POST",
self.path,
content={},
await_result=False,
)
# Advance past all the rate limiting logic. If we disconnect too early, the
# request won't be processed.
self.pump()
self._test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
expected_body={"result": True},
)

View file

@ -12,16 +12,25 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import json
from http import HTTPStatus
from io import BytesIO from io import BytesIO
from typing import Tuple
from unittest.mock import Mock from unittest.mock import Mock
from synapse.api.errors import SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import cancellable
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request, parse_json_object_from_request,
parse_json_value_from_request, parse_json_value_from_request,
) )
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.server import HomeServer
from synapse.types import JsonDict
from tests import unittest from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin
def make_request(content): def make_request(content):
@ -76,3 +85,52 @@ class TestServletUtils(unittest.TestCase):
# Test not an object # Test not an object
with self.assertRaises(SynapseError): with self.assertRaises(SynapseError):
parse_json_object_from_request(make_request(b'["foo"]')) parse_json_object_from_request(make_request(b'["foo"]'))
class CancellableRestServlet(RestServlet):
"""A `RestServlet` with a mix of cancellable and uncancellable handlers."""
PATTERNS = client_patterns("/sleep$")
def __init__(self, hs: HomeServer):
super().__init__()
self.clock = hs.get_clock()
@cancellable
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.clock.sleep(1.0)
return HTTPStatus.OK, {"result": True}
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.clock.sleep(1.0)
return HTTPStatus.OK, {"result": True}
class TestRestServletCancellation(
unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
):
"""Tests for `RestServlet` cancellation."""
servlets = [
lambda hs, http_server: CancellableRestServlet(hs).register(http_server)
]
def test_cancellable_disconnect(self) -> None:
"""Test that handlers with the `@cancellable` flag can be cancelled."""
channel = self.make_request("GET", "/sleep", await_result=False)
self._test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
)
def test_uncancellable_disconnect(self) -> None:
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
channel = self.make_request("POST", "/sleep", await_result=False)
self._test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
expected_body={"result": True},
)

View file

@ -831,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
self.site, self.site,
method=method, method=method,
path=path, path=path,
content=content or "", content=content if content is not None else "",
shorthand=False, shorthand=False,
await_result=await_result, await_result=await_result,
custom_headers=custom_headers, custom_headers=custom_headers,