forked from MirrorHub/synapse
UIA: offer only available auth flows
During user-interactive auth, do not offer password auth to users with no password, nor SSO auth to users with no SSO. Fixes #7559.
This commit is contained in:
parent
76469898ee
commit
0bac276890
6 changed files with 278 additions and 33 deletions
|
@ -193,9 +193,7 @@ class AuthHandler(BaseHandler):
|
|||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
self._password_enabled = hs.config.password_enabled
|
||||
self._sso_enabled = (
|
||||
hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
|
||||
)
|
||||
self._password_localdb_enabled = hs.config.password_localdb_enabled
|
||||
|
||||
# we keep this as a list despite the O(N^2) implication so that we can
|
||||
# keep PASSWORD first and avoid confusing clients which pick the first
|
||||
|
@ -205,7 +203,7 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
# start out by assuming PASSWORD is enabled; we will remove it later if not.
|
||||
login_types = []
|
||||
if hs.config.password_localdb_enabled:
|
||||
if self._password_localdb_enabled:
|
||||
login_types.append(LoginType.PASSWORD)
|
||||
|
||||
for provider in self.password_providers:
|
||||
|
@ -219,14 +217,6 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
self._supported_login_types = login_types
|
||||
|
||||
# Login types and UI Auth types have a heavy overlap, but are not
|
||||
# necessarily identical. Login types have SSO (and other login types)
|
||||
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
|
||||
ui_auth_types = login_types.copy()
|
||||
if self._sso_enabled:
|
||||
ui_auth_types.append(LoginType.SSO)
|
||||
self._supported_ui_auth_types = ui_auth_types
|
||||
|
||||
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
||||
# as per `rc_login.failed_attempts`.
|
||||
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
||||
|
@ -339,7 +329,10 @@ class AuthHandler(BaseHandler):
|
|||
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
|
||||
|
||||
# build a list of supported flows
|
||||
flows = [[login_type] for login_type in self._supported_ui_auth_types]
|
||||
supported_ui_auth_types = await self._get_available_ui_auth_types(
|
||||
requester.user
|
||||
)
|
||||
flows = [[login_type] for login_type in supported_ui_auth_types]
|
||||
|
||||
try:
|
||||
result, params, session_id = await self.check_ui_auth(
|
||||
|
@ -351,7 +344,7 @@ class AuthHandler(BaseHandler):
|
|||
raise
|
||||
|
||||
# find the completed login type
|
||||
for login_type in self._supported_ui_auth_types:
|
||||
for login_type in supported_ui_auth_types:
|
||||
if login_type not in result:
|
||||
continue
|
||||
|
||||
|
@ -367,6 +360,41 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
return params, session_id
|
||||
|
||||
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
|
||||
"""Get a list of the authentication types this user can use
|
||||
"""
|
||||
|
||||
ui_auth_types = set()
|
||||
|
||||
# if the HS supports password auth, and the user has a non-null password, we
|
||||
# support password auth
|
||||
if self._password_localdb_enabled and self._password_enabled:
|
||||
lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
|
||||
if lookupres:
|
||||
_, password_hash = lookupres
|
||||
if password_hash:
|
||||
ui_auth_types.add(LoginType.PASSWORD)
|
||||
|
||||
# also allow auth from password providers
|
||||
for provider in self.password_providers:
|
||||
for t in provider.get_supported_login_types().keys():
|
||||
if t == LoginType.PASSWORD and not self._password_enabled:
|
||||
continue
|
||||
ui_auth_types.add(t)
|
||||
|
||||
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
|
||||
# from sso to mxid.
|
||||
if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
|
||||
if await self.store.get_external_ids_by_user(user.to_string()):
|
||||
ui_auth_types.add(LoginType.SSO)
|
||||
|
||||
# Our CAS impl does not (yet) correctly register users in user_external_ids,
|
||||
# so always offer that if it's available.
|
||||
if self.hs.config.cas.cas_enabled:
|
||||
ui_auth_types.add(LoginType.SSO)
|
||||
|
||||
return ui_auth_types
|
||||
|
||||
def get_enabled_auth_types(self):
|
||||
"""Return the enabled user-interactive authentication types
|
||||
|
||||
|
@ -1029,7 +1057,7 @@ class AuthHandler(BaseHandler):
|
|||
if result:
|
||||
return result
|
||||
|
||||
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
|
||||
if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
|
||||
known_login_type = True
|
||||
|
||||
# we've already checked that there is a (valid) password field
|
||||
|
|
|
@ -463,6 +463,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
desc="get_user_by_external_id",
|
||||
)
|
||||
|
||||
async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
|
||||
"""Look up external ids for the given user
|
||||
|
||||
Args:
|
||||
mxid: the MXID to be looked up
|
||||
|
||||
Returns:
|
||||
Tuples of (auth_provider, external_id)
|
||||
"""
|
||||
res = await self.db_pool.simple_select_list(
|
||||
table="user_external_ids",
|
||||
keyvalues={"user_id": mxid},
|
||||
retcols=("auth_provider", "external_id"),
|
||||
desc="get_external_ids_by_user",
|
||||
)
|
||||
return [(r["auth_provider"], r["external_id"]) for r in res]
|
||||
|
||||
async def count_all_users(self):
|
||||
"""Counts all users registered on the homeserver."""
|
||||
|
||||
|
@ -963,6 +980,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
|||
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
"user_external_ids_user_id_idx",
|
||||
index_name="user_external_ids_user_id_idx",
|
||||
table="user_external_ids",
|
||||
columns=["user_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
async def _background_update_set_deactivated_flag(self, progress, batch_size):
|
||||
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
|
||||
for each of them.
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(5825, 'user_external_ids_user_id_idx', '{}');
|
|
@ -2,7 +2,7 @@
|
|||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
# Copyright 2018-2019 New Vector Ltd
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -17,17 +17,23 @@
|
|||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from mock import patch
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import Site
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from tests.server import FakeSite, make_request
|
||||
from tests.test_utils import FakeResponse
|
||||
|
||||
|
||||
@attr.s
|
||||
|
@ -344,3 +350,111 @@ class RestHelper:
|
|||
)
|
||||
|
||||
return channel.json_body
|
||||
|
||||
def login_via_oidc(self, remote_user_id: str) -> JsonDict:
|
||||
"""Log in (as a new user) via OIDC
|
||||
|
||||
Returns the result of the final token login.
|
||||
|
||||
Requires that "oidc_config" in the homeserver config be set appropriately
|
||||
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
|
||||
"public_base_url".
|
||||
|
||||
Also requires the login servlet and the OIDC callback resource to be mounted at
|
||||
the normal places.
|
||||
"""
|
||||
client_redirect_url = "https://x"
|
||||
|
||||
# first hit the redirect url (which will issue a cookie and state)
|
||||
_, channel = make_request(
|
||||
self.hs.get_reactor(),
|
||||
self.site,
|
||||
"GET",
|
||||
"/login/sso/redirect?redirectUrl=" + client_redirect_url,
|
||||
)
|
||||
# that will redirect to the OIDC IdP, but we skip that and go straight
|
||||
# back to synapse's OIDC callback resource. However, we do need the "state"
|
||||
# param that synapse passes to the IdP via query params, and the cookie that
|
||||
# synapse passes to the client.
|
||||
assert channel.code == 302
|
||||
oauth_uri = channel.headers.getRawHeaders("Location")[0]
|
||||
params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
|
||||
redirect_uri = "%s?%s" % (
|
||||
urllib.parse.urlparse(params["redirect_uri"][0]).path,
|
||||
urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
|
||||
)
|
||||
cookies = {}
|
||||
for h in channel.headers.getRawHeaders("Set-Cookie"):
|
||||
parts = h.split(";")
|
||||
k, v = parts[0].split("=", maxsplit=1)
|
||||
cookies[k] = v
|
||||
|
||||
# before we hit the callback uri, stub out some methods in the http client so
|
||||
# that we don't have to handle full HTTPS requests.
|
||||
|
||||
# (expected url, json response) pairs, in the order we expect them.
|
||||
expected_requests = [
|
||||
# first we get a hit to the token endpoint, which we tell to return
|
||||
# a dummy OIDC access token
|
||||
("https://issuer.test/token", {"access_token": "TEST"}),
|
||||
# and then one to the user_info endpoint, which returns our remote user id.
|
||||
("https://issuer.test/userinfo", {"sub": remote_user_id}),
|
||||
]
|
||||
|
||||
async def mock_req(method: str, uri: str, data=None, headers=None):
|
||||
(expected_uri, resp_obj) = expected_requests.pop(0)
|
||||
assert uri == expected_uri
|
||||
resp = FakeResponse(
|
||||
code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
|
||||
)
|
||||
return resp
|
||||
|
||||
with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
|
||||
# now hit the callback URI with the right params and a made-up code
|
||||
_, channel = make_request(
|
||||
self.hs.get_reactor(),
|
||||
self.site,
|
||||
"GET",
|
||||
redirect_uri,
|
||||
custom_headers=[
|
||||
("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
|
||||
],
|
||||
)
|
||||
|
||||
# expect a confirmation page
|
||||
assert channel.code == 200
|
||||
|
||||
# fish the matrix login token out of the body of the confirmation page
|
||||
m = re.search(
|
||||
'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
|
||||
channel.result["body"].decode("utf-8"),
|
||||
)
|
||||
assert m
|
||||
login_token = m.group(1)
|
||||
|
||||
# finally, submit the matrix login token to the login API, which gives us our
|
||||
# matrix access token and device id.
|
||||
_, channel = make_request(
|
||||
self.hs.get_reactor(),
|
||||
self.site,
|
||||
"POST",
|
||||
"/login",
|
||||
content={"type": "m.login.token", "token": login_token},
|
||||
)
|
||||
assert channel.code == 200
|
||||
return channel.json_body
|
||||
|
||||
|
||||
# an 'oidc_config' suitable for login_with_oidc.
|
||||
TEST_OIDC_CONFIG = {
|
||||
"enabled": True,
|
||||
"discover": False,
|
||||
"issuer": "https://issuer.test",
|
||||
"client_id": "test-client-id",
|
||||
"client_secret": "test-client-secret",
|
||||
"scopes": ["profile"],
|
||||
"authorization_endpoint": "https://z",
|
||||
"token_endpoint": "https://issuer.test/token",
|
||||
"userinfo_endpoint": "https://issuer.test/userinfo",
|
||||
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from twisted.internet.defer import succeed
|
||||
|
@ -22,9 +23,11 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
|||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client.v1 import login
|
||||
from synapse.rest.client.v2_alpha import auth, devices, register
|
||||
from synapse.types import JsonDict
|
||||
from synapse.rest.oidc import OIDCResource
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
||||
from tests import unittest
|
||||
from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
|
||||
from tests.server import FakeChannel
|
||||
|
||||
|
||||
|
@ -156,27 +159,45 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||
register.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
|
||||
# we enable OIDC as a way of testing SSO flows
|
||||
oidc_config = {}
|
||||
oidc_config.update(TEST_OIDC_CONFIG)
|
||||
oidc_config["allow_existing_users"] = True
|
||||
|
||||
config["oidc_config"] = oidc_config
|
||||
config["public_baseurl"] = "https://synapse.test"
|
||||
return config
|
||||
|
||||
def create_resource_dict(self):
|
||||
resource_dict = super().create_resource_dict()
|
||||
# mount the OIDC resource at /_synapse/oidc
|
||||
resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
|
||||
return resource_dict
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.user_pass = "pass"
|
||||
self.user = self.register_user("test", self.user_pass)
|
||||
self.user_tok = self.login("test", self.user_pass)
|
||||
|
||||
def get_device_ids(self) -> List[str]:
|
||||
def get_device_ids(self, access_token: str) -> List[str]:
|
||||
# Get the list of devices so one can be deleted.
|
||||
request, channel = self.make_request(
|
||||
"GET", "devices", access_token=self.user_tok,
|
||||
) # type: SynapseRequest, FakeChannel
|
||||
|
||||
# Get the ID of the device.
|
||||
self.assertEqual(request.code, 200)
|
||||
_, channel = self.make_request("GET", "devices", access_token=access_token,)
|
||||
self.assertEqual(channel.code, 200)
|
||||
return [d["device_id"] for d in channel.json_body["devices"]]
|
||||
|
||||
def delete_device(
|
||||
self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
|
||||
self,
|
||||
access_token: str,
|
||||
device: str,
|
||||
expected_response: int,
|
||||
body: Union[bytes, JsonDict] = b"",
|
||||
) -> FakeChannel:
|
||||
"""Delete an individual device."""
|
||||
request, channel = self.make_request(
|
||||
"DELETE", "devices/" + device, body, access_token=self.user_tok
|
||||
"DELETE", "devices/" + device, body, access_token=access_token,
|
||||
) # type: SynapseRequest, FakeChannel
|
||||
|
||||
# Ensure the response is sane.
|
||||
|
@ -201,11 +222,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||
"""
|
||||
Test user interactive authentication outside of registration.
|
||||
"""
|
||||
device_id = self.get_device_ids()[0]
|
||||
device_id = self.get_device_ids(self.user_tok)[0]
|
||||
|
||||
# Attempt to delete this device.
|
||||
# Returns a 401 as per the spec
|
||||
channel = self.delete_device(device_id, 401)
|
||||
channel = self.delete_device(self.user_tok, device_id, 401)
|
||||
|
||||
# Grab the session
|
||||
session = channel.json_body["session"]
|
||||
|
@ -214,6 +235,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||
|
||||
# Make another request providing the UI auth flow.
|
||||
self.delete_device(
|
||||
self.user_tok,
|
||||
device_id,
|
||||
200,
|
||||
{
|
||||
|
@ -233,12 +255,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||
UIA - check that still works.
|
||||
"""
|
||||
|
||||
device_id = self.get_device_ids()[0]
|
||||
channel = self.delete_device(device_id, 401)
|
||||
device_id = self.get_device_ids(self.user_tok)[0]
|
||||
channel = self.delete_device(self.user_tok, device_id, 401)
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Make another request providing the UI auth flow.
|
||||
self.delete_device(
|
||||
self.user_tok,
|
||||
device_id,
|
||||
200,
|
||||
{
|
||||
|
@ -264,7 +287,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||
# Create a second login.
|
||||
self.login("test", self.user_pass)
|
||||
|
||||
device_ids = self.get_device_ids()
|
||||
device_ids = self.get_device_ids(self.user_tok)
|
||||
self.assertEqual(len(device_ids), 2)
|
||||
|
||||
# Attempt to delete the first device.
|
||||
|
@ -298,12 +321,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||
# Create a second login.
|
||||
self.login("test", self.user_pass)
|
||||
|
||||
device_ids = self.get_device_ids()
|
||||
device_ids = self.get_device_ids(self.user_tok)
|
||||
self.assertEqual(len(device_ids), 2)
|
||||
|
||||
# Attempt to delete the first device.
|
||||
# Returns a 401 as per the spec
|
||||
channel = self.delete_device(device_ids[0], 401)
|
||||
channel = self.delete_device(self.user_tok, device_ids[0], 401)
|
||||
|
||||
# Grab the session
|
||||
session = channel.json_body["session"]
|
||||
|
@ -313,6 +336,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||
# Make another request providing the UI auth flow, but try to delete the
|
||||
# second device. This results in an error.
|
||||
self.delete_device(
|
||||
self.user_tok,
|
||||
device_ids[1],
|
||||
403,
|
||||
{
|
||||
|
@ -324,3 +348,39 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||
},
|
||||
},
|
||||
)
|
||||
|
||||
def test_does_not_offer_password_for_sso_user(self):
|
||||
login_resp = self.helper.login_via_oidc("username")
|
||||
user_tok = login_resp["access_token"]
|
||||
device_id = login_resp["device_id"]
|
||||
|
||||
# now call the device deletion API: we should get the option to auth with SSO
|
||||
# and not password.
|
||||
channel = self.delete_device(user_tok, device_id, 401)
|
||||
|
||||
flows = channel.json_body["flows"]
|
||||
self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
|
||||
|
||||
def test_does_not_offer_sso_for_password_user(self):
|
||||
# now call the device deletion API: we should get the option to auth with SSO
|
||||
# and not password.
|
||||
device_ids = self.get_device_ids(self.user_tok)
|
||||
channel = self.delete_device(self.user_tok, device_ids[0], 401)
|
||||
|
||||
flows = channel.json_body["flows"]
|
||||
self.assertEqual(flows, [{"stages": ["m.login.password"]}])
|
||||
|
||||
def test_offers_both_flows_for_upgraded_user(self):
|
||||
"""A user that had a password and then logged in with SSO should get both flows
|
||||
"""
|
||||
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
|
||||
self.assertEqual(login_resp["user_id"], self.user)
|
||||
|
||||
device_ids = self.get_device_ids(self.user_tok)
|
||||
channel = self.delete_device(self.user_tok, device_ids[0], 401)
|
||||
|
||||
flows = channel.json_body["flows"]
|
||||
# we have no particular expectations of ordering here
|
||||
self.assertIn({"stages": ["m.login.password"]}, flows)
|
||||
self.assertIn({"stages": ["m.login.sso"]}, flows)
|
||||
self.assertEqual(len(flows), 2)
|
||||
|
|
|
@ -259,6 +259,7 @@ def make_request(
|
|||
for k, v in custom_headers:
|
||||
req.requestHeaders.addRawHeader(k, v)
|
||||
|
||||
req.parseCookies()
|
||||
req.requestReceived(method, path, b"1.1")
|
||||
|
||||
if await_result:
|
||||
|
|
Loading…
Reference in a new issue