0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-15 22:42:23 +01:00

Use servlets for /key/ endpoints. (#14229)

To fix the response for unknown endpoints under that prefix.

See MSC3743.
This commit is contained in:
Patrick Cloke 2022-10-20 11:32:47 -04:00 committed by GitHub
parent da2c93d4b6
commit 755bfeee3a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 86 additions and 83 deletions

1
changelog.d/14229.misc Normal file
View file

@ -0,0 +1 @@
Refactor `/key/` endpoints to use `RestServlet` classes.

View file

@ -28,7 +28,7 @@ FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1"
FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2" FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2"
FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable" FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable"
STATIC_PREFIX = "/_matrix/static" STATIC_PREFIX = "/_matrix/static"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" SERVER_KEY_PREFIX = "/_matrix/key"
MEDIA_R0_PREFIX = "/_matrix/media/r0" MEDIA_R0_PREFIX = "/_matrix/media/r0"
MEDIA_V3_PREFIX = "/_matrix/media/v3" MEDIA_V3_PREFIX = "/_matrix/media/v3"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"

View file

@ -28,7 +28,7 @@ from synapse.api.urls import (
LEGACY_MEDIA_PREFIX, LEGACY_MEDIA_PREFIX,
MEDIA_R0_PREFIX, MEDIA_R0_PREFIX,
MEDIA_V3_PREFIX, MEDIA_V3_PREFIX,
SERVER_KEY_V2_PREFIX, SERVER_KEY_PREFIX,
) )
from synapse.app import _base from synapse.app import _base
from synapse.app._base import ( from synapse.app._base import (
@ -89,7 +89,7 @@ from synapse.rest.client.register import (
RegistrationTokenValidityRestServlet, RegistrationTokenValidityRestServlet,
) )
from synapse.rest.health import HealthResource from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.key.v2 import KeyResource
from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.rest.well_known import well_known_resource from synapse.rest.well_known import well_known_resource
from synapse.server import HomeServer from synapse.server import HomeServer
@ -325,13 +325,13 @@ class GenericWorkerServer(HomeServer):
presence.register_servlets(self, resource) presence.register_servlets(self, resource)
resources.update({CLIENT_API_PREFIX: resource}) resources[CLIENT_API_PREFIX] = resource
resources.update(build_synapse_client_resource_tree(self)) resources.update(build_synapse_client_resource_tree(self))
resources.update({"/.well-known": well_known_resource(self)}) resources["/.well-known"] = well_known_resource(self)
elif name == "federation": elif name == "federation":
resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) resources[FEDERATION_PREFIX] = TransportLayerServer(self)
elif name == "media": elif name == "media":
if self.config.media.can_load_media_repo: if self.config.media.can_load_media_repo:
media_repo = self.get_media_repository_resource() media_repo = self.get_media_repository_resource()
@ -359,16 +359,12 @@ class GenericWorkerServer(HomeServer):
# Only load the openid resource separately if federation resource # Only load the openid resource separately if federation resource
# is not specified since federation resource includes openid # is not specified since federation resource includes openid
# resource. # resource.
resources.update( resources[FEDERATION_PREFIX] = TransportLayerServer(
{
FEDERATION_PREFIX: TransportLayerServer(
self, servlet_groups=["openid"] self, servlet_groups=["openid"]
) )
}
)
if name in ["keys", "federation"]: if name in ["keys", "federation"]:
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) resources[SERVER_KEY_PREFIX] = KeyResource(self)
if name == "replication": if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self) resources[REPLICATION_PREFIX] = ReplicationRestResource(self)

View file

@ -31,7 +31,7 @@ from synapse.api.urls import (
LEGACY_MEDIA_PREFIX, LEGACY_MEDIA_PREFIX,
MEDIA_R0_PREFIX, MEDIA_R0_PREFIX,
MEDIA_V3_PREFIX, MEDIA_V3_PREFIX,
SERVER_KEY_V2_PREFIX, SERVER_KEY_PREFIX,
STATIC_PREFIX, STATIC_PREFIX,
) )
from synapse.app import _base from synapse.app import _base
@ -60,7 +60,7 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.rest import ClientRestResource from synapse.rest import ClientRestResource
from synapse.rest.admin import AdminRestResource from synapse.rest.admin import AdminRestResource
from synapse.rest.health import HealthResource from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.key.v2 import KeyResource
from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.rest.well_known import well_known_resource from synapse.rest.well_known import well_known_resource
from synapse.server import HomeServer from synapse.server import HomeServer
@ -215,31 +215,23 @@ class SynapseHomeServer(HomeServer):
consent_resource: Resource = ConsentResource(self) consent_resource: Resource = ConsentResource(self)
if compress: if compress:
consent_resource = gz_wrap(consent_resource) consent_resource = gz_wrap(consent_resource)
resources.update({"/_matrix/consent": consent_resource}) resources["/_matrix/consent"] = consent_resource
if name == "federation": if name == "federation":
federation_resource: Resource = TransportLayerServer(self) federation_resource: Resource = TransportLayerServer(self)
if compress: if compress:
federation_resource = gz_wrap(federation_resource) federation_resource = gz_wrap(federation_resource)
resources.update({FEDERATION_PREFIX: federation_resource}) resources[FEDERATION_PREFIX] = federation_resource
if name == "openid": if name == "openid":
resources.update( resources[FEDERATION_PREFIX] = TransportLayerServer(
{
FEDERATION_PREFIX: TransportLayerServer(
self, servlet_groups=["openid"] self, servlet_groups=["openid"]
) )
}
)
if name in ["static", "client"]: if name in ["static", "client"]:
resources.update( resources[STATIC_PREFIX] = StaticResource(
{
STATIC_PREFIX: StaticResource(
os.path.join(os.path.dirname(synapse.__file__), "static") os.path.join(os.path.dirname(synapse.__file__), "static")
) )
}
)
if name in ["media", "federation", "client"]: if name in ["media", "federation", "client"]:
if self.config.server.enable_media_repo: if self.config.server.enable_media_repo:
@ -257,7 +249,7 @@ class SynapseHomeServer(HomeServer):
) )
if name in ["keys", "federation"]: if name in ["keys", "federation"]:
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) resources[SERVER_KEY_PREFIX] = KeyResource(self)
if name == "metrics" and self.config.metrics.enable_metrics: if name == "metrics" and self.config.metrics.enable_metrics:
metrics_resource: Resource = MetricsResource(RegistryProxy) metrics_resource: Resource = MetricsResource(RegistryProxy)

View file

@ -14,17 +14,20 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from twisted.web.resource import Resource from synapse.http.server import HttpServer, JsonResource
from synapse.rest.key.v2.local_key_resource import LocalKey
from .local_key_resource import LocalKey from synapse.rest.key.v2.remote_key_resource import RemoteKey
from .remote_key_resource import RemoteKey
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
class KeyApiV2Resource(Resource): class KeyResource(JsonResource):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
Resource.__init__(self) super().__init__(hs, canonical_json=True)
self.putChild(b"server", LocalKey(hs)) self.register_servlets(self, hs)
self.putChild(b"query", RemoteKey(hs))
@staticmethod
def register_servlets(http_server: HttpServer, hs: "HomeServer") -> None:
LocalKey(hs).register(http_server)
RemoteKey(hs).register(http_server)

View file

@ -13,16 +13,15 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Optional import re
from typing import TYPE_CHECKING, Optional, Tuple
from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
from twisted.web.resource import Resource from twisted.web.server import Request
from synapse.http.server import respond_with_json_bytes from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict from synapse.types import JsonDict
if TYPE_CHECKING: if TYPE_CHECKING:
@ -31,7 +30,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LocalKey(Resource): class LocalKey(RestServlet):
"""HTTP resource containing encoding the TLS X.509 certificate and NACL """HTTP resource containing encoding the TLS X.509 certificate and NACL
signature verification keys for this server:: signature verification keys for this server::
@ -61,18 +60,17 @@ class LocalKey(Resource):
} }
""" """
isLeaf = True PATTERNS = (re.compile("^/_matrix/key/v2/server(/(?P<key_id>[^/]*))?$"),)
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.config = hs.config self.config = hs.config
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec()) self.update_response_body(self.clock.time_msec())
Resource.__init__(self)
def update_response_body(self, time_now_msec: int) -> None: def update_response_body(self, time_now_msec: int) -> None:
refresh_interval = self.config.key.key_refresh_interval refresh_interval = self.config.key.key_refresh_interval
self.valid_until_ts = int(time_now_msec + refresh_interval) self.valid_until_ts = int(time_now_msec + refresh_interval)
self.response_body = encode_canonical_json(self.response_json_object()) self.response_body = self.response_json_object()
def response_json_object(self) -> JsonDict: def response_json_object(self) -> JsonDict:
verify_keys = {} verify_keys = {}
@ -99,9 +97,11 @@ class LocalKey(Resource):
json_object = sign_json(json_object, self.config.server.server_name, key) json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object return json_object
def render_GET(self, request: SynapseRequest) -> Optional[int]: def on_GET(
self, request: Request, key_id: Optional[str] = None
) -> Tuple[int, JsonDict]:
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains. # Update the expiry time if less than half the interval remains.
if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts:
self.update_response_body(time_now) self.update_response_body(time_now)
return respond_with_json_bytes(request, 200, self.response_body) return 200, self.response_body

View file

@ -13,15 +13,20 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Dict, Set import re
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
from signedjson.sign import sign_json from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError from twisted.web.server import Request
from synapse.crypto.keyring import ServerKeyFetcher from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import HttpServer
from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.http.servlet import (
from synapse.http.site import SynapseRequest RestServlet,
parse_integer,
parse_json_object_from_request,
)
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results from synapse.util.async_helpers import yieldable_gather_results
@ -32,7 +37,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RemoteKey(DirectServeJsonResource): class RemoteKey(RestServlet):
"""HTTP resource for retrieving the TLS certificate and NACL signature """HTTP resource for retrieving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks X.509 TLS certificate matches the one used in the HTTPS connection. Checks
@ -88,11 +93,7 @@ class RemoteKey(DirectServeJsonResource):
} }
""" """
isLeaf = True
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__()
self.fetcher = ServerKeyFetcher(hs) self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -101,36 +102,48 @@ class RemoteKey(DirectServeJsonResource):
) )
self.config = hs.config self.config = hs.config
async def _async_render_GET(self, request: SynapseRequest) -> None: def register(self, http_server: HttpServer) -> None:
assert request.postpath is not None http_server.register_paths(
if len(request.postpath) == 1: "GET",
(server,) = request.postpath (
query: dict = {server.decode("ascii"): {}} re.compile(
elif len(request.postpath) == 2: "^/_matrix/key/v2/query/(?P<server>[^/]*)(/(?P<key_id>[^/]*))?$"
server, key_id = request.postpath ),
),
self.on_GET,
self.__class__.__name__,
)
http_server.register_paths(
"POST",
(re.compile("^/_matrix/key/v2/query$"),),
self.on_POST,
self.__class__.__name__,
)
async def on_GET(
self, request: Request, server: str, key_id: Optional[str] = None
) -> Tuple[int, JsonDict]:
if server and key_id:
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
arguments = {} arguments = {}
if minimum_valid_until_ts is not None: if minimum_valid_until_ts is not None:
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}} query = {server: {key_id: arguments}}
else: else:
raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND) query = {server: {}}
await self.query_keys(request, query, query_remote_on_cache_miss=True) return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
async def _async_render_POST(self, request: SynapseRequest) -> None: async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
query = content["server_keys"] query = content["server_keys"]
await self.query_keys(request, query, query_remote_on_cache_miss=True) return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
async def query_keys( async def query_keys(
self, self, query: JsonDict, query_remote_on_cache_miss: bool = False
request: SynapseRequest, ) -> JsonDict:
query: JsonDict,
query_remote_on_cache_miss: bool = False,
) -> None:
logger.info("Handling query for keys %r", query) logger.info("Handling query for keys %r", query)
store_queries = [] store_queries = []
@ -232,7 +245,7 @@ class RemoteKey(DirectServeJsonResource):
for server_name, keys in cache_misses.items() for server_name, keys in cache_misses.items()
), ),
) )
await self.query_keys(request, query, query_remote_on_cache_miss=False) return await self.query_keys(query, query_remote_on_cache_miss=False)
else: else:
signed_keys = [] signed_keys = []
for key_json_raw in json_results: for key_json_raw in json_results:
@ -244,6 +257,4 @@ class RemoteKey(DirectServeJsonResource):
signed_keys.append(key_json) signed_keys.append(key_json)
response = {"server_keys": signed_keys} return {"server_keys": signed_keys}
respond_with_json(request, 200, response, canonical_json=True)

View file

@ -79,7 +79,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
self.assertEqual(channel.code, 401) self.assertEqual(channel.code, 401)
@patch("synapse.app.homeserver.KeyApiV2Resource", new=Mock()) @patch("synapse.app.homeserver.KeyResource", new=Mock())
class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(

View file

@ -26,7 +26,7 @@ from twisted.web.resource import NoResource, Resource
from synapse.crypto.keyring import PerspectivesKeyFetcher from synapse.crypto.keyring import PerspectivesKeyFetcher
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.key.v2 import KeyResource
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict from synapse.types import JsonDict
@ -46,7 +46,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
def create_test_resource(self) -> Resource: def create_test_resource(self) -> Resource:
return create_resource_tree( return create_resource_tree(
{"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() {"/_matrix/key/v2": KeyResource(self.hs)}, root_resource=NoResource()
) )
def expect_outgoing_key_request( def expect_outgoing_key_request(