0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 13:23:54 +01:00

Merge remote-tracking branch 'origin/develop' into rav/remove_unused_mocks

This commit is contained in:
Richard van der Hoff 2020-12-02 20:08:46 +00:00
commit 269ba1bc84
15 changed files with 367 additions and 111 deletions

1
changelog.d/8858.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a long-standing bug on Synapse instances supporting Single-Sign-On, where users would be prompted to enter their password to confirm certain actions, even though they have not set a password.

1
changelog.d/8864.misc Normal file
View file

@ -0,0 +1 @@
Remove unused `FakeResponse` class from unit tests.

View file

@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
Args: Args:
hs (synapse.server.HomeServer): homeserver hs (synapse.server.HomeServer): homeserver
resource (TransportLayerServer): resource class to register to resource (JsonResource): resource class to register to
authenticator (Authenticator): authenticator to use authenticator (Authenticator): authenticator to use
ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
servlet_groups (list[str], optional): List of servlet groups to register. servlet_groups (list[str], optional): List of servlet groups to register.

View file

@ -193,9 +193,7 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled self._password_enabled = hs.config.password_enabled
self._sso_enabled = ( self._password_localdb_enabled = hs.config.password_localdb_enabled
hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
)
# we keep this as a list despite the O(N^2) implication so that we can # we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first # keep PASSWORD first and avoid confusing clients which pick the first
@ -205,7 +203,7 @@ class AuthHandler(BaseHandler):
# start out by assuming PASSWORD is enabled; we will remove it later if not. # start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = [] login_types = []
if hs.config.password_localdb_enabled: if self._password_localdb_enabled:
login_types.append(LoginType.PASSWORD) login_types.append(LoginType.PASSWORD)
for provider in self.password_providers: for provider in self.password_providers:
@ -219,14 +217,6 @@ class AuthHandler(BaseHandler):
self._supported_login_types = login_types self._supported_login_types = login_types
# Login types and UI Auth types have a heavy overlap, but are not
# necessarily identical. Login types have SSO (and other login types)
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
ui_auth_types = login_types.copy()
if self._sso_enabled:
ui_auth_types.append(LoginType.SSO)
self._supported_ui_auth_types = ui_auth_types
# Ratelimiter for failed auth during UIA. Uses same ratelimit config # Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`. # as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter( self._failed_uia_attempts_ratelimiter = Ratelimiter(
@ -339,7 +329,10 @@ class AuthHandler(BaseHandler):
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
# build a list of supported flows # build a list of supported flows
flows = [[login_type] for login_type in self._supported_ui_auth_types] supported_ui_auth_types = await self._get_available_ui_auth_types(
requester.user
)
flows = [[login_type] for login_type in supported_ui_auth_types]
try: try:
result, params, session_id = await self.check_ui_auth( result, params, session_id = await self.check_ui_auth(
@ -351,7 +344,7 @@ class AuthHandler(BaseHandler):
raise raise
# find the completed login type # find the completed login type
for login_type in self._supported_ui_auth_types: for login_type in supported_ui_auth_types:
if login_type not in result: if login_type not in result:
continue continue
@ -367,6 +360,41 @@ class AuthHandler(BaseHandler):
return params, session_id return params, session_id
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
"""Get a list of the authentication types this user can use
"""
ui_auth_types = set()
# if the HS supports password auth, and the user has a non-null password, we
# support password auth
if self._password_localdb_enabled and self._password_enabled:
lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
if lookupres:
_, password_hash = lookupres
if password_hash:
ui_auth_types.add(LoginType.PASSWORD)
# also allow auth from password providers
for provider in self.password_providers:
for t in provider.get_supported_login_types().keys():
if t == LoginType.PASSWORD and not self._password_enabled:
continue
ui_auth_types.add(t)
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
# from sso to mxid.
if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
if await self.store.get_external_ids_by_user(user.to_string()):
ui_auth_types.add(LoginType.SSO)
# Our CAS impl does not (yet) correctly register users in user_external_ids,
# so always offer that if it's available.
if self.hs.config.cas.cas_enabled:
ui_auth_types.add(LoginType.SSO)
return ui_auth_types
def get_enabled_auth_types(self): def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types """Return the enabled user-interactive authentication types
@ -1029,7 +1057,7 @@ class AuthHandler(BaseHandler):
if result: if result:
return result return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True known_login_type = True
# we've already checked that there is a (valid) password field # we've already checked that there is a (valid) password field

View file

@ -463,6 +463,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_external_id", desc="get_user_by_external_id",
) )
async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
"""Look up external ids for the given user
Args:
mxid: the MXID to be looked up
Returns:
Tuples of (auth_provider, external_id)
"""
res = await self.db_pool.simple_select_list(
table="user_external_ids",
keyvalues={"user_id": mxid},
retcols=("auth_provider", "external_id"),
desc="get_external_ids_by_user",
)
return [(r["auth_provider"], r["external_id"]) for r in res]
async def count_all_users(self): async def count_all_users(self):
"""Counts all users registered on the homeserver.""" """Counts all users registered on the homeserver."""
@ -963,6 +980,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
"users_set_deactivated_flag", self._background_update_set_deactivated_flag "users_set_deactivated_flag", self._background_update_set_deactivated_flag
) )
self.db_pool.updates.register_background_index_update(
"user_external_ids_user_id_idx",
index_name="user_external_ids_user_id_idx",
table="user_external_ids",
columns=["user_id"],
unique=False,
)
async def _background_update_set_deactivated_flag(self, progress, batch_size): async def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them. for each of them.

View file

@ -0,0 +1,17 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5825, 'user_external_ids_user_id_idx', '{}');

View file

@ -17,30 +17,15 @@ from urllib.parse import parse_qs, urlparse
from mock import Mock, patch from mock import Mock, patch
import attr
import pymacaroons import pymacaroons
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException from synapse.handlers.sso import MappingException
from synapse.types import UserID from synapse.types import UserID
from tests.test_utils import FakeResponse
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
@attr.s
class FakeResponse:
code = attr.ib()
body = attr.ib()
phrase = attr.ib()
def deliverBody(self, protocol):
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))
# These are a few constants that are used as config parameters in the tests. # These are a few constants that are used as config parameters in the tests.
ISSUER = "https://issuer/" ISSUER = "https://issuer/"
CLIENT_ID = "test-client-id" CLIENT_ID = "test-client-id"

View file

@ -15,15 +15,16 @@
import json import json
from typing import Dict
from mock import ANY, Mock, call from mock import ANY, Mock, call
from twisted.internet import defer from twisted.internet import defer
from twisted.web.resource import Resource
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.federation.transport import server as federation_server from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
@ -53,20 +54,7 @@ def _make_edu_transaction_json(edu_type, content):
return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8") return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")
def register_federation_servlets(hs, resource):
federation_server.register_servlets(
hs,
resource=resource,
authenticator=federation_server.Authenticator(hs),
ratelimiter=FederationRateLimiter(
hs.get_clock(), config=hs.config.rc_federation
),
)
class TypingNotificationsTestCase(unittest.HomeserverTestCase): class TypingNotificationsTestCase(unittest.HomeserverTestCase):
servlets = [register_federation_servlets]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
# we mock out the keyring so as to skip the authentication check on the # we mock out the keyring so as to skip the authentication check on the
# federation API call. # federation API call.
@ -89,6 +77,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
return hs return hs
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_matrix/federation"] = TransportLayerServer(self.hs)
return d
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
mock_notifier = hs.get_notifier() mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event self.on_new_event = mock_notifier.on_new_event

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Callable, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import attr import attr
@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel from twisted.web.http import HTTPChannel
from twisted.web.resource import Resource
from synapse.app.generic_worker import ( from synapse.app.generic_worker import (
GenericWorkerReplicationHandler, GenericWorkerReplicationHandler,
@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource, streams from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis: if not hiredis:
skip = "Requires hiredis" skip = "Requires hiredis"
servlets = [
streams.register_servlets,
]
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
# build a replication server # build a replication server
server_factory = ReplicationStreamProtocolFactory(hs) server_factory = ReplicationStreamProtocolFactory(hs)
@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None self._client_transport = None
self._server_transport = None self._server_transport = None
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_synapse/replication"] = ReplicationRestResource(self.hs)
return d
def _get_worker_hs_config(self) -> dict: def _get_worker_hs_config(self) -> dict:
config = self.default_config() config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker" config["worker_app"] = "synapse.app.generic_worker"

View file

@ -2,7 +2,7 @@
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd # Copyright 2017 Vector Creations Ltd
# Copyright 2018-2019 New Vector Ltd # Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright 2019-2020 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,17 +17,23 @@
# limitations under the License. # limitations under the License.
import json import json
import re
import time import time
import urllib.parse
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from mock import patch
import attr import attr
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Site from twisted.web.server import Site
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.types import JsonDict
from tests.server import FakeSite, make_request from tests.server import FakeSite, make_request
from tests.test_utils import FakeResponse
@attr.s @attr.s
@ -344,3 +350,111 @@ class RestHelper:
) )
return channel.json_body return channel.json_body
def login_via_oidc(self, remote_user_id: str) -> JsonDict:
"""Log in (as a new user) via OIDC
Returns the result of the final token login.
Requires that "oidc_config" in the homeserver config be set appropriately
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
"public_base_url".
Also requires the login servlet and the OIDC callback resource to be mounted at
the normal places.
"""
client_redirect_url = "https://x"
# first hit the redirect url (which will issue a cookie and state)
_, channel = make_request(
self.hs.get_reactor(),
self.site,
"GET",
"/login/sso/redirect?redirectUrl=" + client_redirect_url,
)
# that will redirect to the OIDC IdP, but we skip that and go straight
# back to synapse's OIDC callback resource. However, we do need the "state"
# param that synapse passes to the IdP via query params, and the cookie that
# synapse passes to the client.
assert channel.code == 302
oauth_uri = channel.headers.getRawHeaders("Location")[0]
params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
redirect_uri = "%s?%s" % (
urllib.parse.urlparse(params["redirect_uri"][0]).path,
urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
)
cookies = {}
for h in channel.headers.getRawHeaders("Set-Cookie"):
parts = h.split(";")
k, v = parts[0].split("=", maxsplit=1)
cookies[k] = v
# before we hit the callback uri, stub out some methods in the http client so
# that we don't have to handle full HTTPS requests.
# (expected url, json response) pairs, in the order we expect them.
expected_requests = [
# first we get a hit to the token endpoint, which we tell to return
# a dummy OIDC access token
("https://issuer.test/token", {"access_token": "TEST"}),
# and then one to the user_info endpoint, which returns our remote user id.
("https://issuer.test/userinfo", {"sub": remote_user_id}),
]
async def mock_req(method: str, uri: str, data=None, headers=None):
(expected_uri, resp_obj) = expected_requests.pop(0)
assert uri == expected_uri
resp = FakeResponse(
code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
)
return resp
with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
# now hit the callback URI with the right params and a made-up code
_, channel = make_request(
self.hs.get_reactor(),
self.site,
"GET",
redirect_uri,
custom_headers=[
("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
],
)
# expect a confirmation page
assert channel.code == 200
# fish the matrix login token out of the body of the confirmation page
m = re.search(
'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
channel.result["body"].decode("utf-8"),
)
assert m
login_token = m.group(1)
# finally, submit the matrix login token to the login API, which gives us our
# matrix access token and device id.
_, channel = make_request(
self.hs.get_reactor(),
self.site,
"POST",
"/login",
content={"type": "m.login.token", "token": login_token},
)
assert channel.code == 200
return channel.json_body
# an 'oidc_config' suitable for login_with_oidc.
TEST_OIDC_CONFIG = {
"enabled": True,
"discover": False,
"issuer": "https://issuer.test",
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"scopes": ["profile"],
"authorization_endpoint": "https://z",
"token_endpoint": "https://issuer.test/token",
"userinfo_endpoint": "https://issuer.test/userinfo",
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
}

View file

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Union from typing import List, Union
from twisted.internet.defer import succeed from twisted.internet.defer import succeed
@ -22,9 +23,11 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.rest.client.v1 import login from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import auth, devices, register from synapse.rest.client.v2_alpha import auth, devices, register
from synapse.types import JsonDict from synapse.rest.oidc import OIDCResource
from synapse.types import JsonDict, UserID
from tests import unittest from tests import unittest
from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel from tests.server import FakeChannel
@ -156,27 +159,45 @@ class UIAuthTests(unittest.HomeserverTestCase):
register.register_servlets, register.register_servlets,
] ]
def default_config(self):
config = super().default_config()
# we enable OIDC as a way of testing SSO flows
oidc_config = {}
oidc_config.update(TEST_OIDC_CONFIG)
oidc_config["allow_existing_users"] = True
config["oidc_config"] = oidc_config
config["public_baseurl"] = "https://synapse.test"
return config
def create_resource_dict(self):
resource_dict = super().create_resource_dict()
# mount the OIDC resource at /_synapse/oidc
resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
return resource_dict
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.user_pass = "pass" self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass) self.user = self.register_user("test", self.user_pass)
self.user_tok = self.login("test", self.user_pass) self.user_tok = self.login("test", self.user_pass)
def get_device_ids(self) -> List[str]: def get_device_ids(self, access_token: str) -> List[str]:
# Get the list of devices so one can be deleted. # Get the list of devices so one can be deleted.
request, channel = self.make_request( _, channel = self.make_request("GET", "devices", access_token=access_token,)
"GET", "devices", access_token=self.user_tok, self.assertEqual(channel.code, 200)
) # type: SynapseRequest, FakeChannel
# Get the ID of the device.
self.assertEqual(request.code, 200)
return [d["device_id"] for d in channel.json_body["devices"]] return [d["device_id"] for d in channel.json_body["devices"]]
def delete_device( def delete_device(
self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b"" self,
access_token: str,
device: str,
expected_response: int,
body: Union[bytes, JsonDict] = b"",
) -> FakeChannel: ) -> FakeChannel:
"""Delete an individual device.""" """Delete an individual device."""
request, channel = self.make_request( request, channel = self.make_request(
"DELETE", "devices/" + device, body, access_token=self.user_tok "DELETE", "devices/" + device, body, access_token=access_token,
) # type: SynapseRequest, FakeChannel ) # type: SynapseRequest, FakeChannel
# Ensure the response is sane. # Ensure the response is sane.
@ -201,11 +222,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
""" """
Test user interactive authentication outside of registration. Test user interactive authentication outside of registration.
""" """
device_id = self.get_device_ids()[0] device_id = self.get_device_ids(self.user_tok)[0]
# Attempt to delete this device. # Attempt to delete this device.
# Returns a 401 as per the spec # Returns a 401 as per the spec
channel = self.delete_device(device_id, 401) channel = self.delete_device(self.user_tok, device_id, 401)
# Grab the session # Grab the session
session = channel.json_body["session"] session = channel.json_body["session"]
@ -214,6 +235,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow. # Make another request providing the UI auth flow.
self.delete_device( self.delete_device(
self.user_tok,
device_id, device_id,
200, 200,
{ {
@ -233,12 +255,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
UIA - check that still works. UIA - check that still works.
""" """
device_id = self.get_device_ids()[0] device_id = self.get_device_ids(self.user_tok)[0]
channel = self.delete_device(device_id, 401) channel = self.delete_device(self.user_tok, device_id, 401)
session = channel.json_body["session"] session = channel.json_body["session"]
# Make another request providing the UI auth flow. # Make another request providing the UI auth flow.
self.delete_device( self.delete_device(
self.user_tok,
device_id, device_id,
200, 200,
{ {
@ -264,7 +287,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Create a second login. # Create a second login.
self.login("test", self.user_pass) self.login("test", self.user_pass)
device_ids = self.get_device_ids() device_ids = self.get_device_ids(self.user_tok)
self.assertEqual(len(device_ids), 2) self.assertEqual(len(device_ids), 2)
# Attempt to delete the first device. # Attempt to delete the first device.
@ -298,12 +321,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Create a second login. # Create a second login.
self.login("test", self.user_pass) self.login("test", self.user_pass)
device_ids = self.get_device_ids() device_ids = self.get_device_ids(self.user_tok)
self.assertEqual(len(device_ids), 2) self.assertEqual(len(device_ids), 2)
# Attempt to delete the first device. # Attempt to delete the first device.
# Returns a 401 as per the spec # Returns a 401 as per the spec
channel = self.delete_device(device_ids[0], 401) channel = self.delete_device(self.user_tok, device_ids[0], 401)
# Grab the session # Grab the session
session = channel.json_body["session"] session = channel.json_body["session"]
@ -313,6 +336,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow, but try to delete the # Make another request providing the UI auth flow, but try to delete the
# second device. This results in an error. # second device. This results in an error.
self.delete_device( self.delete_device(
self.user_tok,
device_ids[1], device_ids[1],
403, 403,
{ {
@ -324,3 +348,39 @@ class UIAuthTests(unittest.HomeserverTestCase):
}, },
}, },
) )
def test_does_not_offer_password_for_sso_user(self):
login_resp = self.helper.login_via_oidc("username")
user_tok = login_resp["access_token"]
device_id = login_resp["device_id"]
# now call the device deletion API: we should get the option to auth with SSO
# and not password.
channel = self.delete_device(user_tok, device_id, 401)
flows = channel.json_body["flows"]
self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
def test_does_not_offer_sso_for_password_user(self):
# now call the device deletion API: we should get the option to auth with SSO
# and not password.
device_ids = self.get_device_ids(self.user_tok)
channel = self.delete_device(self.user_tok, device_ids[0], 401)
flows = channel.json_body["flows"]
self.assertEqual(flows, [{"stages": ["m.login.password"]}])
def test_offers_both_flows_for_upgraded_user(self):
"""A user that had a password and then logged in with SSO should get both flows
"""
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
self.assertEqual(login_resp["user_id"], self.user)
device_ids = self.get_device_ids(self.user_tok)
channel = self.delete_device(self.user_tok, device_ids[0], 401)
flows = channel.json_body["flows"]
# we have no particular expectations of ordering here
self.assertIn({"stages": ["m.login.password"]}, flows)
self.assertIn({"stages": ["m.login.sso"]}, flows)
self.assertEqual(len(flows), 2)

View file

@ -18,41 +18,15 @@ import re
from mock import patch from mock import patch
import attr
from twisted.internet._resolver import HostResolution from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web._newclient import ResponseDone
from tests import unittest from tests import unittest
from tests.server import FakeTransport from tests.server import FakeTransport
@attr.s
class FakeResponse:
version = attr.ib()
code = attr.ib()
phrase = attr.ib()
headers = attr.ib()
body = attr.ib()
absoluteURI = attr.ib()
@property
def request(self):
@attr.s
class FakeTransport:
absoluteURI = self.absoluteURI
return FakeTransport()
def deliverBody(self, protocol):
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))
class URLPreviewTests(unittest.HomeserverTestCase): class URLPreviewTests(unittest.HomeserverTestCase):
hijack_auth = True hijack_auth = True

View file

@ -216,8 +216,9 @@ def make_request(
and not path.startswith(b"/_matrix") and not path.startswith(b"/_matrix")
and not path.startswith(b"/_synapse") and not path.startswith(b"/_synapse")
): ):
if path.startswith(b"/"):
path = path[1:]
path = b"/_matrix/client/r0/" + path path = b"/_matrix/client/r0/" + path
path = path.replace(b"//", b"/")
if not path.startswith(b"/"): if not path.startswith(b"/"):
path = b"/" + path path = b"/" + path
@ -258,6 +259,7 @@ def make_request(
for k, v in custom_headers: for k, v in custom_headers:
req.requestHeaders.addRawHeader(k, v) req.requestHeaders.addRawHeader(k, v)
req.parseCookies()
req.requestReceived(method, path, b"1.1") req.requestReceived(method, path, b"1.1")
if await_result: if await_result:

View file

@ -22,6 +22,11 @@ import warnings
from asyncio import Future from asyncio import Future
from typing import Any, Awaitable, Callable, TypeVar from typing import Any, Awaitable, Callable, TypeVar
import attr
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
TV = TypeVar("TV") TV = TypeVar("TV")
@ -80,3 +85,25 @@ def setup_awaitable_errors() -> Callable[[], None]:
sys.unraisablehook = unraisablehook # type: ignore sys.unraisablehook = unraisablehook # type: ignore
return cleanup return cleanup
@attr.s
class FakeResponse:
"""A fake twisted.web.IResponse object
there is a similar class at treq.test.test_response, but it lacks a `phrase`
attribute, and didn't support deliverBody until recently.
"""
# HTTP response code
code = attr.ib(type=int)
# HTTP response phrase (eg b'OK' for a 200)
phrase = attr.ib(type=bytes)
# body of the response
body = attr.ib(type=bytes)
def deliverBody(self, protocol):
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))

View file

@ -20,7 +20,7 @@ import hmac
import inspect import inspect
import logging import logging
import time import time
from typing import Optional, Tuple, Type, TypeVar, Union, overload from typing import Dict, Optional, Tuple, Type, TypeVar, Union, overload
from mock import Mock, patch from mock import Mock, patch
@ -46,6 +46,7 @@ from synapse.logging.context import (
) )
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase):
""" """
Create a the root resource for the test server. Create a the root resource for the test server.
The default implementation creates a JsonResource and calls each function in The default calls `self.create_resource_dict` and builds the resultant dict
`servlets` to register servletes against it into a tree.
""" """
resource = JsonResource(self.hs) root_resource = Resource()
create_resource_tree(self.create_resource_dict(), root_resource)
return root_resource
def create_resource_dict(self) -> Dict[str, Resource]:
"""Create a resource tree for the test server
A resource tree is a mapping from path to twisted.web.resource.
The default implementation creates a JsonResource and calls each function in
`servlets` to register servlets against it.
"""
servlet_resource = JsonResource(self.hs)
for servlet in self.servlets: for servlet in self.servlets:
servlet(self.hs, resource) servlet(self.hs, servlet_resource)
return {
return resource "/_matrix/client": servlet_resource,
"/_synapse/admin": servlet_resource,
}
def default_config(self): def default_config(self):
""" """
@ -691,13 +705,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
A federating homeserver that authenticates incoming requests as `other.example.com`. A federating homeserver that authenticates incoming requests as `other.example.com`.
""" """
def prepare(self, reactor, clock, homeserver): def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
return d
class TestTransportLayerServer(JsonResource):
"""A test implementation of TransportLayerServer
authenticates incoming requests as `other.example.com`.
"""
def __init__(self, hs):
super().__init__(hs)
class Authenticator: class Authenticator:
def authenticate_request(self, request, content): def authenticate_request(self, request, content):
return succeed("other.example.com") return succeed("other.example.com")
authenticator = Authenticator()
ratelimiter = FederationRateLimiter( ratelimiter = FederationRateLimiter(
clock, hs.get_clock(),
FederationRateLimitConfig( FederationRateLimitConfig(
window_size=1, window_size=1,
sleep_limit=1, sleep_limit=1,
@ -706,11 +736,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
concurrent_requests=1000, concurrent_requests=1000,
), ),
) )
federation_server.register_servlets(
homeserver, self.resource, Authenticator(), ratelimiter
)
return super().prepare(reactor, clock, homeserver) federation_server.register_servlets(hs, self, authenticator, ratelimiter)
def override_config(extra_config): def override_config(extra_config):