0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-16 06:51:46 +01:00

Revert MSC3861 introspection cache, admin impersonation and account lock (#16258)

This commit is contained in:
Quentin Gliech 2023-09-06 16:19:51 +02:00 committed by GitHub
parent 1cd0715a0f
commit 1940d990a3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 31 additions and 391 deletions

1
changelog.d/16258.bugfix Normal file
View file

@ -0,0 +1 @@
Revert MSC3861 introspection cache, admin impersonation and account lock.

View file

@ -28,7 +28,6 @@ from twisted.web.http_headers import Headers
from synapse.api.auth.base import BaseAuth from synapse.api.auth.base import BaseAuth
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes,
HttpResponseException, HttpResponseException,
InvalidClientTokenError, InvalidClientTokenError,
OAuthInsufficientScopeError, OAuthInsufficientScopeError,
@ -40,7 +39,6 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.types import Requester, UserID, create_requester from synapse.types import Requester, UserID, create_requester
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.caches.expiringcache import ExpiringCache
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -109,20 +107,13 @@ class MSC3861DelegatedAuth(BaseAuth):
assert self._config.client_id, "No client_id provided" assert self._config.client_id, "No client_id provided"
assert auth_method is not None, "Invalid client_auth_method provided" assert auth_method is not None, "Invalid client_auth_method provided"
self._clock = hs.get_clock()
self._http_client = hs.get_proxied_http_client() self._http_client = hs.get_proxied_http_client()
self._hostname = hs.hostname self._hostname = hs.hostname
self._admin_token = self._config.admin_token self._admin_token = self._config.admin_token
self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata) self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata)
self._clock = hs.get_clock()
self._token_cache: ExpiringCache[str, IntrospectionToken] = ExpiringCache(
cache_name="introspection_token_cache",
clock=self._clock,
max_len=10000,
expiry_ms=5 * 60 * 1000,
)
if isinstance(auth_method, PrivateKeyJWTWithKid): if isinstance(auth_method, PrivateKeyJWTWithKid):
# Use the JWK as the client secret when using the private_key_jwt method # Use the JWK as the client secret when using the private_key_jwt method
assert self._config.jwk, "No JWK provided" assert self._config.jwk, "No JWK provided"
@ -161,20 +152,6 @@ class MSC3861DelegatedAuth(BaseAuth):
Returns: Returns:
The introspection response The introspection response
""" """
# check the cache before doing a request
introspection_token = self._token_cache.get(token, None)
if introspection_token:
# check the expiration field of the token (if it exists)
exp = introspection_token.get("exp", None)
if exp:
time_now = self._clock.time()
expired = time_now > exp
if not expired:
return introspection_token
else:
return introspection_token
metadata = await self._issuer_metadata.get() metadata = await self._issuer_metadata.get()
introspection_endpoint = metadata.get("introspection_endpoint") introspection_endpoint = metadata.get("introspection_endpoint")
raw_headers: Dict[str, str] = { raw_headers: Dict[str, str] = {
@ -188,10 +165,7 @@ class MSC3861DelegatedAuth(BaseAuth):
# Fill the body/headers with credentials # Fill the body/headers with credentials
uri, raw_headers, body = self._client_auth.prepare( uri, raw_headers, body = self._client_auth.prepare(
method="POST", method="POST", uri=introspection_endpoint, headers=raw_headers, body=body
uri=introspection_endpoint,
headers=raw_headers,
body=body,
) )
headers = Headers({k: [v] for (k, v) in raw_headers.items()}) headers = Headers({k: [v] for (k, v) in raw_headers.items()})
@ -233,20 +207,10 @@ class MSC3861DelegatedAuth(BaseAuth):
"The introspection endpoint returned an invalid JSON response." "The introspection endpoint returned an invalid JSON response."
) )
expiration = resp.get("exp", None) return IntrospectionToken(**resp)
if expiration:
if self._clock.time() > expiration:
raise InvalidClientTokenError("Token is expired.")
introspection_token = IntrospectionToken(**resp)
# add token to cache
self._token_cache[token] = introspection_token
return introspection_token
async def is_server_admin(self, requester: Requester) -> bool: async def is_server_admin(self, requester: Requester) -> bool:
return SCOPE_SYNAPSE_ADMIN in requester.scope return "urn:synapse:admin:*" in requester.scope
async def get_user_by_req( async def get_user_by_req(
self, self,
@ -263,36 +227,6 @@ class MSC3861DelegatedAuth(BaseAuth):
# so that we don't provision the user if they don't have enough permission: # so that we don't provision the user if they don't have enough permission:
requester = await self.get_user_by_access_token(access_token, allow_expired) requester = await self.get_user_by_access_token(access_token, allow_expired)
# Allow impersonation by an admin user using `_oidc_admin_impersonate_user_id` query parameter
if request.args is not None:
user_id_params = request.args.get(b"_oidc_admin_impersonate_user_id")
if user_id_params:
if await self.is_server_admin(requester):
user_id_str = user_id_params[0].decode("ascii")
impersonated_user_id = UserID.from_string(user_id_str)
logging.info(f"Admin impersonation of user {user_id_str}")
requester = create_requester(
user_id=impersonated_user_id,
scope=[SCOPE_MATRIX_API],
authenticated_entity=requester.user.to_string(),
)
else:
raise AuthError(
401,
"Impersonation not possible by a non admin user",
)
# Deny the request if the user account is locked.
if not allow_locked and await self.store.get_user_locked_status(
requester.user.to_string()
):
raise AuthError(
401,
"User account has been locked",
errcode=Codes.USER_LOCKED,
additional_fields={"soft_logout": True},
)
if not allow_guest and requester.is_guest: if not allow_guest and requester.is_guest:
raise OAuthInsufficientScopeError([SCOPE_MATRIX_API]) raise OAuthInsufficientScopeError([SCOPE_MATRIX_API])
@ -309,14 +243,14 @@ class MSC3861DelegatedAuth(BaseAuth):
# XXX: This is a temporary solution so that the admin API can be called by # XXX: This is a temporary solution so that the admin API can be called by
# the OIDC provider. This will be removed once we have OIDC client # the OIDC provider. This will be removed once we have OIDC client
# credentials grant support in matrix-authentication-service. # credentials grant support in matrix-authentication-service.
logging.info("Admin token used") logging.info("Admin toked used")
# XXX: that user doesn't exist and won't be provisioned. # XXX: that user doesn't exist and won't be provisioned.
# This is mostly fine for admin calls, but we should also think about doing # This is mostly fine for admin calls, but we should also think about doing
# requesters without a user_id. # requesters without a user_id.
admin_user = UserID("__oidc_admin", self._hostname) admin_user = UserID("__oidc_admin", self._hostname)
return create_requester( return create_requester(
user_id=admin_user, user_id=admin_user,
scope=[SCOPE_SYNAPSE_ADMIN], scope=["urn:synapse:admin:*"],
) )
try: try:
@ -438,16 +372,3 @@ class MSC3861DelegatedAuth(BaseAuth):
scope=scope, scope=scope,
is_guest=(has_guest_scope and not has_user_scope), is_guest=(has_guest_scope and not has_user_scope),
) )
def invalidate_cached_tokens(self, keys: List[str]) -> None:
"""
Invalidate the entry(s) in the introspection token cache corresponding to the given key
"""
for key in keys:
self._token_cache.invalidate(key)
def invalidate_token_cache(self) -> None:
"""
Invalidate the entire token cache.
"""
self._token_cache.invalidate_all()

View file

@ -28,7 +28,6 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import ( from synapse.replication.tcp.streams import (
AccountDataStream, AccountDataStream,
CachesStream,
DeviceListsStream, DeviceListsStream,
PushersStream, PushersStream,
PushRulesStream, PushRulesStream,
@ -76,7 +75,6 @@ class ReplicationDataHandler:
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
self._typing_handler = hs.get_typing_handler() self._typing_handler = hs.get_typing_handler()
self._state_storage_controller = hs.get_storage_controllers().state self._state_storage_controller = hs.get_storage_controllers().state
self.auth = hs.get_auth()
self._notify_pushers = hs.config.worker.start_pushers self._notify_pushers = hs.config.worker.start_pushers
self._pusher_pool = hs.get_pusherpool() self._pusher_pool = hs.get_pusherpool()
@ -224,16 +222,6 @@ class ReplicationDataHandler:
self._state_storage_controller.notify_event_un_partial_stated( self._state_storage_controller.notify_event_un_partial_stated(
row.event_id row.event_id
) )
# invalidate the introspection token cache
elif stream_name == CachesStream.NAME:
for row in rows:
if row.cache_func == "introspection_token_invalidation":
if row.keys[0] is None:
# invalidate the whole cache
# mypy ignore - the token cache is defined on MSC3861DelegatedAuth
self.auth.invalidate_token_cache() # type: ignore[attr-defined]
else:
self.auth.invalidate_cached_tokens(row.keys) # type: ignore[attr-defined]
await self._presence_handler.process_replication_rows( await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows stream_name, instance_name, token, rows

View file

@ -47,7 +47,6 @@ from synapse.rest.admin.federation import (
ListDestinationsRestServlet, ListDestinationsRestServlet,
) )
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.oidc import OIDCTokenRevocationRestServlet
from synapse.rest.admin.registration_tokens import ( from synapse.rest.admin.registration_tokens import (
ListRegistrationTokensRestServlet, ListRegistrationTokensRestServlet,
NewRegistrationTokenRestServlet, NewRegistrationTokenRestServlet,
@ -298,8 +297,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
BackgroundUpdateRestServlet(hs).register(http_server) BackgroundUpdateRestServlet(hs).register(http_server)
BackgroundUpdateStartJobRestServlet(hs).register(http_server) BackgroundUpdateStartJobRestServlet(hs).register(http_server)
ExperimentalFeaturesRestServlet(hs).register(http_server) ExperimentalFeaturesRestServlet(hs).register(http_server)
if hs.config.experimental.msc3861.enabled:
OIDCTokenRevocationRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource( def register_servlets_for_client_rest_resource(

View file

@ -1,55 +0,0 @@
# Copyright 2023 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Tuple
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
if TYPE_CHECKING:
from synapse.server import HomeServer
class OIDCTokenRevocationRestServlet(RestServlet):
"""
Delete a given token introspection response - identified by the `jti` field - from the
introspection token cache when a token is revoked at the authorizing server
"""
PATTERNS = admin_patterns("/OIDC_token_revocation/(?P<token_id>[^/]*)")
def __init__(self, hs: "HomeServer"):
super().__init__()
auth = hs.get_auth()
# If this endpoint is loaded then we must have enabled delegated auth.
from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth
assert isinstance(auth, MSC3861DelegatedAuth)
self.auth = auth
self.store = hs.get_datastores().main
async def on_DELETE(
self, request: SynapseRequest, token_id: str
) -> Tuple[HTTPStatus, Dict]:
await assert_requester_is_admin(self.auth, request)
self.auth._token_cache.invalidate(token_id)
# make sure we invalidate the cache on any workers
await self.store.stream_introspection_token_invalidation((token_id,))
return HTTPStatus.OK, {}

View file

@ -584,19 +584,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
else: else:
return 0 return 0
async def stream_introspection_token_invalidation(
self, key: Tuple[Optional[str]]
) -> None:
"""
Stream an invalidation request for the introspection token cache to workers
Args:
key: token_id of the introspection token to remove from the cache
"""
await self.send_invalidation_to_replication(
"introspection_token_invalidation", key
)
@wrap_as_background_process("clean_up_old_cache_invalidations") @wrap_as_background_process("clean_up_old_cache_invalidations")
async def _clean_up_cache_invalidation_wrapper(self) -> None: async def _clean_up_cache_invalidation_wrapper(self) -> None:
""" """

View file

@ -33,7 +33,6 @@ from typing_extensions import Literal
from synapse.api.constants import EduTypes from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, StoreError from synapse.api.errors import Codes, StoreError
from synapse.config.homeserver import HomeServerConfig
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
get_active_span_text_map, get_active_span_text_map,
set_tag, set_tag,
@ -1664,7 +1663,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.device_id_exists_cache: LruCache[ self.device_id_exists_cache: LruCache[
Tuple[str, str], Literal[True] Tuple[str, str], Literal[True]
] = LruCache(cache_name="device_id_exists", max_size=10000) ] = LruCache(cache_name="device_id_exists", max_size=10000)
self.config: HomeServerConfig = hs.config
async def store_device( async def store_device(
self, self,
@ -1786,13 +1784,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
for device_id in device_ids: for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id)) self.device_id_exists_cache.invalidate((user_id, device_id))
# TODO: don't nuke the entire cache once there is a way to associate
# device_id -> introspection_token
if self.config.experimental.msc3861.enabled:
# mypy ignore - the token cache is defined on MSC3861DelegatedAuth
self.auth._token_cache.invalidate_all() # type: ignore[attr-defined]
await self.stream_introspection_token_invalidation((None,))
async def update_device( async def update_device(
self, user_id: str, device_id: str, new_display_name: Optional[str] = None self, user_id: str, device_id: str, new_display_name: Optional[str] = None
) -> None: ) -> None:

View file

@ -140,20 +140,6 @@ class ExpiringCache(Generic[KT, VT]):
return value.value return value.value
def invalidate(self, key: KT) -> None:
"""
Remove the given key from the cache.
"""
value = self._cache.pop(key, None)
if value:
if self.iterable:
self.metrics.inc_evictions(
EvictionReason.invalidation, len(value.value)
)
else:
self.metrics.inc_evictions(EvictionReason.invalidation)
def __contains__(self, key: KT) -> bool: def __contains__(self, key: KT) -> bool:
return key in self._cache return key in self._cache
@ -207,14 +193,6 @@ class ExpiringCache(Generic[KT, VT]):
len(self), len(self),
) )
def invalidate_all(self) -> None:
"""
Remove all items from the cache.
"""
keys = set(self._cache.keys())
for key in keys:
self._cache.pop(key)
def __len__(self) -> int: def __len__(self) -> int:
if self.iterable: if self.iterable:
return sum(len(entry.value) for entry in self._cache.values()) return sum(len(entry.value) for entry in self._cache.values())

View file

@ -14,7 +14,7 @@
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Dict, Union from typing import Any, Dict, Union
from unittest.mock import ANY, AsyncMock, Mock from unittest.mock import ANY, Mock
from urllib.parse import parse_qs from urllib.parse import parse_qs
from signedjson.key import ( from signedjson.key import (
@ -122,6 +122,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
"client_id": CLIENT_ID, "client_id": CLIENT_ID,
"client_auth_method": "client_secret_post", "client_auth_method": "client_secret_post",
"client_secret": CLIENT_SECRET, "client_secret": CLIENT_SECRET,
"admin_token": "admin_token_value",
} }
} }
return config return config
@ -340,41 +341,6 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
get_awaitable_result(self.auth.is_server_admin(requester)), False get_awaitable_result(self.auth.is_server_admin(requester)), False
) )
def test_active_user_admin_impersonation(self) -> None:
"""The handler should return a requester with normal user rights
and an user ID matching the one specified in query param `user_id`"""
self.http_client.request = simple_async_mock(
return_value=FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
"username": USERNAME,
},
)
)
request = Mock(args={})
request.args[b"access_token"] = [b"mockAccessToken"]
impersonated_user_id = f"@{USERNAME}:{SERVER_NAME}"
request.args[b"_oidc_admin_impersonate_user_id"] = [
impersonated_user_id.encode("ascii")
]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.http_client.request.assert_called_once_with(
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
)
self._assertParams()
self.assertEqual(requester.user.to_string(), impersonated_user_id)
self.assertEqual(requester.is_guest, False)
self.assertEqual(requester.device_id, None)
self.assertEqual(
get_awaitable_result(self.auth.is_server_admin(requester)), False
)
def test_active_user_with_device(self) -> None: def test_active_user_with_device(self) -> None:
"""The handler should return a requester with normal user rights and a device ID.""" """The handler should return a requester with normal user rights and a device ID."""
@ -526,100 +492,6 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503) self.assertEqual(error.value.code, 503)
def test_introspection_token_cache(self) -> None:
access_token = "open_sesame"
self.http_client.request = simple_async_mock(
return_value=FakeResponse.json(
code=200,
payload={"active": "true", "scope": "guest", "jti": access_token},
)
)
# first call should cache response
# Mpyp ignores below are due to mypy not understanding the dynamic substitution of msc3861 auth code
# for regular auth code via the config
self.get_success(
self.auth._introspect_token(access_token) # type: ignore[attr-defined]
)
introspection_token = self.auth._token_cache.get(access_token) # type: ignore[attr-defined]
self.assertEqual(introspection_token["jti"], access_token)
# there's been one http request
self.http_client.request.assert_called_once()
# second call should pull from cache, there should still be only one http request
token = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
self.http_client.request.assert_called_once()
self.assertEqual(token["jti"], access_token)
# advance past five minutes and check that cache expired - there should be more than one http call now
self.reactor.advance(360)
token_2 = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
self.assertEqual(self.http_client.request.call_count, 2)
self.assertEqual(token_2["jti"], access_token)
# test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a
# token with a soon-to-expire `exp` field to the cache
self.http_client.request = simple_async_mock(
return_value=FakeResponse.json(
code=200,
payload={
"active": "true",
"scope": "guest",
"jti": "stale",
"exp": self.clock.time() + 100,
},
)
)
self.get_success(
self.auth._introspect_token("stale") # type: ignore[attr-defined]
)
introspection_token = self.auth._token_cache.get("stale") # type: ignore[attr-defined]
self.assertEqual(introspection_token["jti"], "stale")
self.assertEqual(self.http_client.request.call_count, 1)
# advance the reactor past the token expiry but less than the cache expiry
self.reactor.advance(120)
self.assertEqual(self.auth._token_cache.get("stale"), introspection_token) # type: ignore[attr-defined]
# check that the next call causes another http request (which will fail because the token is technically expired
# but the important thing is we discard the token from the cache and try the network)
self.get_failure(
self.auth._introspect_token("stale"), InvalidClientTokenError # type: ignore[attr-defined]
)
self.assertEqual(self.http_client.request.call_count, 2)
def test_revocation_endpoint(self) -> None:
# mock introspection response and then admin verification response
self.http_client.request = AsyncMock(
side_effect=[
FakeResponse.json(
code=200, payload={"active": True, "jti": "open_sesame"}
),
FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
"username": USERNAME,
},
),
]
)
# cache a token to delete
introspection_token = self.get_success(
self.auth._introspect_token("open_sesame") # type: ignore[attr-defined]
)
self.assertEqual(self.auth._token_cache.get("open_sesame"), introspection_token) # type: ignore[attr-defined]
# delete the revoked token
introspection_token_id = "open_sesame"
url = f"/_synapse/admin/v1/OIDC_token_revocation/{introspection_token_id}"
channel = self.make_request("DELETE", url, access_token="mockAccessToken")
self.assertEqual(channel.code, 200)
self.assertEqual(self.auth._token_cache.get("open_sesame"), None) # type: ignore[attr-defined]
def make_device_keys(self, user_id: str, device_id: str) -> JsonDict: def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
# We only generate a master key to simplify the test. # We only generate a master key to simplify the test.
master_signing_key = generate_signing_key(device_id) master_signing_key = generate_signing_key(device_id)
@ -791,3 +663,25 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
self.expect_unrecognized("GET", "/_synapse/admin/v1/users/foo/admin") self.expect_unrecognized("GET", "/_synapse/admin/v1/users/foo/admin")
self.expect_unrecognized("PUT", "/_synapse/admin/v1/users/foo/admin") self.expect_unrecognized("PUT", "/_synapse/admin/v1/users/foo/admin")
self.expect_unrecognized("POST", "/_synapse/admin/v1/account_validity/validity") self.expect_unrecognized("POST", "/_synapse/admin/v1/account_validity/validity")
def test_admin_token(self) -> None:
"""The handler should return a requester with admin rights when admin_token is used."""
self.http_client.request = simple_async_mock(
return_value=FakeResponse.json(code=200, payload={"active": False}),
)
request = Mock(args={})
request.args[b"access_token"] = [b"admin_token_value"]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
self.assertEqual(
requester.user.to_string(), "@%s:%s" % ("__oidc_admin", SERVER_NAME)
)
self.assertEqual(requester.is_guest, False)
self.assertEqual(requester.device_id, None)
self.assertEqual(
get_awaitable_result(self.auth.is_server_admin(requester)), True
)
# There should be no call to the introspection endpoint
self.http_client.request.assert_not_called()

View file

@ -1,62 +0,0 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
import synapse.rest.admin._base
from tests.replication._base import BaseMultiWorkerStreamTestCase
class IntrospectionTokenCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
servlets = [synapse.rest.admin.register_servlets]
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["disable_registration"] = True
config["experimental_features"] = {
"msc3861": {
"enabled": True,
"issuer": "some_dude",
"client_id": "ID",
"client_auth_method": "client_secret_post",
"client_secret": "secret",
}
}
return config
def test_stream_introspection_token_invalidation(self) -> None:
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
auth = worker_hs.get_auth()
store = self.hs.get_datastores().main
# add a token to the cache on the worker
auth._token_cache["open_sesame"] = "intro_token" # type: ignore[attr-defined]
# stream the invalidation from the master
self.get_success(
store.stream_introspection_token_invalidation(("open_sesame",))
)
# check that the cache on the worker was invalidated
self.assertEqual(auth._token_cache.get("open_sesame"), None) # type: ignore[attr-defined]
# test invalidating whole cache
for i in range(0, 5):
auth._token_cache[f"open_sesame_{i}"] = f"intro_token_{i}" # type: ignore[attr-defined]
self.assertEqual(len(auth._token_cache), 5) # type: ignore[attr-defined]
self.get_success(store.stream_introspection_token_invalidation((None,)))
self.assertEqual(len(auth._token_cache), 0) # type: ignore[attr-defined]