0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-05 23:44:07 +01:00

Allow expired accounts to logout (#7443)

This commit is contained in:
Andrew Morgan 2020-05-14 16:32:49 +01:00 committed by GitHub
parent 4734a7bbe4
commit 225c165087
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 140 additions and 22 deletions

1
changelog.d/7443.bugfix Normal file
View file

@ -0,0 +1 @@
Allow expired user accounts to log out their device sessions.

View file

@ -22,6 +22,7 @@ import pymacaroons
from netaddr import IPAddress from netaddr import IPAddress
from twisted.internet import defer from twisted.internet import defer
from twisted.web.server import Request
import synapse.logging.opentracing as opentracing import synapse.logging.opentracing as opentracing
import synapse.types import synapse.types
@ -162,19 +163,25 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_req( def get_user_by_req(
self, request, allow_guest=False, rights="access", allow_expired=False self,
request: Request,
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
): ):
""" Get a registered user's ID. """ Get a registered user's ID.
Args: Args:
request - An HTTP request with an access_token query parameter. request: An HTTP request with an access_token query parameter.
allow_expired - Whether to allow the request through even if the account is allow_guest: If False, will raise an AuthError if the user making the
expired. If true, Synapse will still require an access token to be request is a guest.
provided but won't check if the account it belongs to has expired. This rights: The operation being performed; the access token must allow this
works thanks to /login delivering access tokens regardless of accounts' allow_expired: If True, allow the request through even if the account
expiration. is expired, or session token lifetime has ended. Note that
/login will deliver access tokens regardless of expiration.
Returns: Returns:
defer.Deferred: resolves to a ``synapse.types.Requester`` object defer.Deferred: resolves to a `synapse.types.Requester` object
Raises: Raises:
InvalidClientCredentialsError if no user by that token exists or the token InvalidClientCredentialsError if no user by that token exists or the token
is invalid. is invalid.
@ -205,7 +212,9 @@ class Auth(object):
return synapse.types.create_requester(user_id, app_service=app_service) return synapse.types.create_requester(user_id, app_service=app_service)
user_info = yield self.get_user_by_access_token(access_token, rights) user_info = yield self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired
)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
is_guest = user_info["is_guest"] is_guest = user_info["is_guest"]
@ -280,13 +289,17 @@ class Auth(object):
return user_id, app_service return user_id, app_service
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_access_token(self, token, rights="access"): def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False,
):
""" Validate access token and get user_id from it """ Validate access token and get user_id from it
Args: Args:
token (str): The access token to get the user by. token: The access token to get the user by
rights (str): The operation being performed; the access token must rights: The operation being performed; the access token must
allow this. allow this
allow_expired: If False, raises an InvalidClientTokenError
if the token is expired
Returns: Returns:
Deferred[dict]: dict that includes: Deferred[dict]: dict that includes:
`user` (UserID) `user` (UserID)
@ -294,8 +307,10 @@ class Auth(object):
`token_id` (int|None): access token id. May be None if guest `token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token `device_id` (str|None): device corresponding to access token
Raises: Raises:
InvalidClientTokenError if a user by that token exists, but the token is
expired
InvalidClientCredentialsError if no user by that token exists or the token InvalidClientCredentialsError if no user by that token exists or the token
is invalid. is invalid
""" """
if rights == "access": if rights == "access":
@ -304,7 +319,8 @@ class Auth(object):
if r: if r:
valid_until_ms = r["valid_until_ms"] valid_until_ms = r["valid_until_ms"]
if ( if (
valid_until_ms is not None not allow_expired
and valid_until_ms is not None
and valid_until_ms < self.clock.time_msec() and valid_until_ms < self.clock.time_msec()
): ):
# there was a valid access token, but it has expired. # there was a valid access token, but it has expired.
@ -575,7 +591,7 @@ class Auth(object):
return user_level >= send_level return user_level >= send_level
@staticmethod @staticmethod
def has_access_token(request): def has_access_token(request: Request):
"""Checks if the request has an access_token. """Checks if the request has an access_token.
Returns: Returns:
@ -586,7 +602,7 @@ class Auth(object):
return bool(query_params) or bool(auth_headers) return bool(query_params) or bool(auth_headers)
@staticmethod @staticmethod
def get_access_token_from_request(request): def get_access_token_from_request(request: Request):
"""Extracts the access_token from the request. """Extracts the access_token from the request.
Args: Args:

View file

@ -34,10 +34,10 @@ class LogoutRestServlet(RestServlet):
return 200, {} return 200, {}
async def on_POST(self, request): async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request, allow_expired=True)
if requester.device_id is None: if requester.device_id is None:
# the acccess token wasn't associated with a device. # The access token wasn't associated with a device.
# Just delete the access token # Just delete the access token
access_token = self.auth.get_access_token_from_request(request) access_token = self.auth.get_access_token_from_request(request)
await self._auth_handler.delete_access_token(access_token) await self._auth_handler.delete_access_token(access_token)
@ -62,7 +62,7 @@ class LogoutAllRestServlet(RestServlet):
return 200, {} return 200, {}
async def on_POST(self, request): async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
# first delete all of the user's devices # first delete all of the user's devices

View file

@ -4,7 +4,7 @@ import urllib.parse
from mock import Mock from mock import Mock
import synapse.rest.admin import synapse.rest.admin
from synapse.rest.client.v1 import login from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices from synapse.rest.client.v2_alpha import devices
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
@ -20,6 +20,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource, synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets, login.register_servlets,
logout.register_servlets,
devices.register_servlets, devices.register_servlets,
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
] ]
@ -256,6 +257,72 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.render(request) self.render(request)
self.assertEquals(channel.code, 200, channel.result) self.assertEquals(channel.code, 200, channel.result)
@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_after_being_soft_logged_out(self):
self.register_user("kermit", "monkey")
# log in as normal
access_token = self.login("kermit", "monkey")
# we should now be able to make requests with the access token
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
self.render(request)
self.assertEquals(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
self.render(request)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
# Now try to hard logout this session
request, channel = self.make_request(
b"POST", "/logout", access_token=access_token
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self):
self.register_user("kermit", "monkey")
# log in as normal
access_token = self.login("kermit", "monkey")
# we should now be able to make requests with the access token
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
self.render(request)
self.assertEquals(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
request, channel = self.make_request(
b"GET", TEST_URL, access_token=access_token
)
self.render(request)
self.assertEquals(channel.code, 401, channel.result)
self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEquals(channel.json_body["soft_logout"], True)
# Now try to hard log out all of the user's sessions
request, channel = self.make_request(
b"POST", "/logout/all", access_token=access_token
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
class CASTestCase(unittest.HomeserverTestCase): class CASTestCase(unittest.HomeserverTestCase):

View file

@ -25,7 +25,7 @@ import synapse.rest.admin
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import account, account_validity, register, sync from synapse.rest.client.v2_alpha import account, account_validity, register, sync
from tests import unittest from tests import unittest
@ -313,6 +313,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource, synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets, login.register_servlets,
sync.register_servlets, sync.register_servlets,
logout.register_servlets,
account_validity.register_servlets, account_validity.register_servlets,
] ]
@ -405,6 +406,39 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
) )
def test_logging_out_expired_user(self):
user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
self.register_user("admin", "adminpassword", admin=True)
admin_tok = self.login("admin", "adminpassword")
url = "/_matrix/client/unstable/admin/account_validity/validity"
params = {
"user_id": user_id,
"expiration_ts": 0,
"enable_renewal_emails": False,
}
request_data = json.dumps(params)
request, channel = self.make_request(
b"POST", url, request_data, access_token=admin_tok
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Try to log the user out
request, channel = self.make_request(b"POST", "/logout", access_token=tok)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Log the user in again (allowed for expired accounts)
tok = self.login("kermit", "monkey")
# Try to log out all of the user's sessions
request, channel = self.make_request(b"POST", "/logout/all", access_token=tok)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):