0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-16 06:51:46 +01:00

Add an admin endpoint to allow authorizing server to signal token revocations (#16125)

This commit is contained in:
Shay 2023-08-22 07:15:34 -07:00 committed by GitHub
parent 8aa5479986
commit 69048f7b48
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 223 additions and 1 deletions

1
changelog.d/16125.misc Normal file
View file

@ -0,0 +1 @@
Add an admin endpoint to allow authorizing server to signal token revocations.

View file

@ -438,3 +438,16 @@ class MSC3861DelegatedAuth(BaseAuth):
scope=scope, scope=scope,
is_guest=(has_guest_scope and not has_user_scope), is_guest=(has_guest_scope and not has_user_scope),
) )
def invalidate_cached_tokens(self, keys: List[str]) -> None:
"""
Invalidate the entry(s) in the introspection token cache corresponding to the given key
"""
for key in keys:
self._token_cache.invalidate(key)
def invalidate_token_cache(self) -> None:
"""
Invalidate the entire token cache.
"""
self._token_cache.invalidate_all()

View file

@ -26,6 +26,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import ( from synapse.replication.tcp.streams import (
AccountDataStream, AccountDataStream,
CachesStream,
DeviceListsStream, DeviceListsStream,
PushersStream, PushersStream,
PushRulesStream, PushRulesStream,
@ -73,6 +74,7 @@ class ReplicationDataHandler:
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
self._typing_handler = hs.get_typing_handler() self._typing_handler = hs.get_typing_handler()
self._state_storage_controller = hs.get_storage_controllers().state self._state_storage_controller = hs.get_storage_controllers().state
self.auth = hs.get_auth()
self._notify_pushers = hs.config.worker.start_pushers self._notify_pushers = hs.config.worker.start_pushers
self._pusher_pool = hs.get_pusherpool() self._pusher_pool = hs.get_pusherpool()
@ -218,6 +220,16 @@ class ReplicationDataHandler:
self._state_storage_controller.notify_event_un_partial_stated( self._state_storage_controller.notify_event_un_partial_stated(
row.event_id row.event_id
) )
# invalidate the introspection token cache
elif stream_name == CachesStream.NAME:
for row in rows:
if row.cache_func == "introspection_token_invalidation":
if row.keys[0] is None:
# invalidate the whole cache
# mypy ignore - the token cache is defined on MSC3861DelegatedAuth
self.auth.invalidate_token_cache() # type: ignore[attr-defined]
else:
self.auth.invalidate_cached_tokens(row.keys) # type: ignore[attr-defined]
await self._presence_handler.process_replication_rows( await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows stream_name, instance_name, token, rows

View file

@ -47,6 +47,7 @@ from synapse.rest.admin.federation import (
ListDestinationsRestServlet, ListDestinationsRestServlet,
) )
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.oidc import OIDCTokenRevocationRestServlet
from synapse.rest.admin.registration_tokens import ( from synapse.rest.admin.registration_tokens import (
ListRegistrationTokensRestServlet, ListRegistrationTokensRestServlet,
NewRegistrationTokenRestServlet, NewRegistrationTokenRestServlet,
@ -297,6 +298,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
BackgroundUpdateRestServlet(hs).register(http_server) BackgroundUpdateRestServlet(hs).register(http_server)
BackgroundUpdateStartJobRestServlet(hs).register(http_server) BackgroundUpdateStartJobRestServlet(hs).register(http_server)
ExperimentalFeaturesRestServlet(hs).register(http_server) ExperimentalFeaturesRestServlet(hs).register(http_server)
if hs.config.experimental.msc3861.enabled:
OIDCTokenRevocationRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource( def register_servlets_for_client_rest_resource(

View file

@ -0,0 +1,55 @@
# Copyright 2023 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Tuple
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
if TYPE_CHECKING:
from synapse.server import HomeServer
class OIDCTokenRevocationRestServlet(RestServlet):
"""
Delete a given token introspection response - identified by the `jti` field - from the
introspection token cache when a token is revoked at the authorizing server
"""
PATTERNS = admin_patterns("/OIDC_token_revocation/(?P<token_id>[^/]*)")
def __init__(self, hs: "HomeServer"):
super().__init__()
auth = hs.get_auth()
# If this endpoint is loaded then we must have enabled delegated auth.
from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth
assert isinstance(auth, MSC3861DelegatedAuth)
self.auth = auth
self.store = hs.get_datastores().main
async def on_DELETE(
self, request: SynapseRequest, token_id: str
) -> Tuple[HTTPStatus, Dict]:
await assert_requester_is_admin(self.auth, request)
self.auth._token_cache.invalidate(token_id)
# make sure we invalidate the cache on any workers
await self.store.stream_introspection_token_invalidation((token_id,))
return HTTPStatus.OK, {}

View file

@ -584,6 +584,19 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
else: else:
return 0 return 0
async def stream_introspection_token_invalidation(
self, key: Tuple[Optional[str]]
) -> None:
"""
Stream an invalidation request for the introspection token cache to workers
Args:
key: token_id of the introspection token to remove from the cache
"""
await self.send_invalidation_to_replication(
"introspection_token_invalidation", key
)
@wrap_as_background_process("clean_up_old_cache_invalidations") @wrap_as_background_process("clean_up_old_cache_invalidations")
async def _clean_up_cache_invalidation_wrapper(self) -> None: async def _clean_up_cache_invalidation_wrapper(self) -> None:
""" """

View file

@ -33,6 +33,7 @@ from typing_extensions import Literal
from synapse.api.constants import EduTypes from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, StoreError from synapse.api.errors import Codes, StoreError
from synapse.config.homeserver import HomeServerConfig
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
get_active_span_text_map, get_active_span_text_map,
set_tag, set_tag,
@ -1663,6 +1664,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.device_id_exists_cache: LruCache[ self.device_id_exists_cache: LruCache[
Tuple[str, str], Literal[True] Tuple[str, str], Literal[True]
] = LruCache(cache_name="device_id_exists", max_size=10000) ] = LruCache(cache_name="device_id_exists", max_size=10000)
self.config: HomeServerConfig = hs.config
async def store_device( async def store_device(
self, self,
@ -1784,6 +1786,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
for device_id in device_ids: for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id)) self.device_id_exists_cache.invalidate((user_id, device_id))
# TODO: don't nuke the entire cache once there is a way to associate
# device_id -> introspection_token
if self.config.experimental.msc3861.enabled:
# mypy ignore - the token cache is defined on MSC3861DelegatedAuth
self.auth._token_cache.invalidate_all() # type: ignore[attr-defined]
await self.stream_introspection_token_invalidation((None,))
async def update_device( async def update_device(
self, user_id: str, device_id: str, new_display_name: Optional[str] = None self, user_id: str, device_id: str, new_display_name: Optional[str] = None
) -> None: ) -> None:

View file

@ -140,6 +140,20 @@ class ExpiringCache(Generic[KT, VT]):
return value.value return value.value
def invalidate(self, key: KT) -> None:
"""
Remove the given key from the cache.
"""
value = self._cache.pop(key, None)
if value:
if self.iterable:
self.metrics.inc_evictions(
EvictionReason.invalidation, len(value.value)
)
else:
self.metrics.inc_evictions(EvictionReason.invalidation)
def __contains__(self, key: KT) -> bool: def __contains__(self, key: KT) -> bool:
return key in self._cache return key in self._cache
@ -193,6 +207,14 @@ class ExpiringCache(Generic[KT, VT]):
len(self), len(self),
) )
def invalidate_all(self) -> None:
"""
Remove all items from the cache.
"""
keys = set(self._cache.keys())
for key in keys:
self._cache.pop(key)
def __len__(self) -> int: def __len__(self) -> int:
if self.iterable: if self.iterable:
return sum(len(entry.value) for entry in self._cache.values()) return sum(len(entry.value) for entry in self._cache.values())

View file

@ -14,7 +14,7 @@
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Dict, Union from typing import Any, Dict, Union
from unittest.mock import ANY, Mock from unittest.mock import ANY, AsyncMock, Mock
from urllib.parse import parse_qs from urllib.parse import parse_qs
from signedjson.key import ( from signedjson.key import (
@ -588,6 +588,38 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
) )
self.assertEqual(self.http_client.request.call_count, 2) self.assertEqual(self.http_client.request.call_count, 2)
def test_revocation_endpoint(self) -> None:
# mock introspection response and then admin verification response
self.http_client.request = AsyncMock(
side_effect=[
FakeResponse.json(
code=200, payload={"active": True, "jti": "open_sesame"}
),
FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
"username": USERNAME,
},
),
]
)
# cache a token to delete
introspection_token = self.get_success(
self.auth._introspect_token("open_sesame") # type: ignore[attr-defined]
)
self.assertEqual(self.auth._token_cache.get("open_sesame"), introspection_token) # type: ignore[attr-defined]
# delete the revoked token
introspection_token_id = "open_sesame"
url = f"/_synapse/admin/v1/OIDC_token_revocation/{introspection_token_id}"
channel = self.make_request("DELETE", url, access_token="mockAccessToken")
self.assertEqual(channel.code, 200)
self.assertEqual(self.auth._token_cache.get("open_sesame"), None) # type: ignore[attr-defined]
def make_device_keys(self, user_id: str, device_id: str) -> JsonDict: def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
# We only generate a master key to simplify the test. # We only generate a master key to simplify the test.
master_signing_key = generate_signing_key(device_id) master_signing_key = generate_signing_key(device_id)

View file

@ -0,0 +1,62 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
import synapse.rest.admin._base
from tests.replication._base import BaseMultiWorkerStreamTestCase
class IntrospectionTokenCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
servlets = [synapse.rest.admin.register_servlets]
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["disable_registration"] = True
config["experimental_features"] = {
"msc3861": {
"enabled": True,
"issuer": "some_dude",
"client_id": "ID",
"client_auth_method": "client_secret_post",
"client_secret": "secret",
}
}
return config
def test_stream_introspection_token_invalidation(self) -> None:
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
auth = worker_hs.get_auth()
store = self.hs.get_datastores().main
# add a token to the cache on the worker
auth._token_cache["open_sesame"] = "intro_token" # type: ignore[attr-defined]
# stream the invalidation from the master
self.get_success(
store.stream_introspection_token_invalidation(("open_sesame",))
)
# check that the cache on the worker was invalidated
self.assertEqual(auth._token_cache.get("open_sesame"), None) # type: ignore[attr-defined]
# test invalidating whole cache
for i in range(0, 5):
auth._token_cache[f"open_sesame_{i}"] = f"intro_token_{i}" # type: ignore[attr-defined]
self.assertEqual(len(auth._token_cache), 5) # type: ignore[attr-defined]
self.get_success(store.stream_introspection_token_invalidation((None,)))
self.assertEqual(len(auth._token_cache), 0) # type: ignore[attr-defined]