Support trying multiple localparts for OpenID Connect. (#8801)

Abstracts the SAML and OpenID Connect code which attempts to regenerate
the localpart of a matrix ID if it is already in use.
This commit is contained in:
Patrick Cloke 2020-11-25 10:04:22 -05:00 committed by GitHub
parent f38676d161
commit 4fd222ad70
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 331 additions and 137 deletions

1
changelog.d/8801.feature Normal file
View file

@ -0,0 +1 @@
Add support for re-trying generation of a localpart for OpenID Connect mapping providers.

View file

@ -63,13 +63,22 @@ A custom mapping provider must specify the following methods:
information from.
- This method must return a string, which is the unique identifier for the
user. Commonly the ``sub`` claim of the response.
* `map_user_attributes(self, userinfo, token)`
* `map_user_attributes(self, userinfo, token, failures)`
- This method must be async.
- Arguments:
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
information from.
- `token` - A dictionary which includes information necessary to make
further requests to the OpenID provider.
- `failures` - An `int` that represents the amount of times the returned
mxid localpart mapping has failed. This should be used
to create a deduplicated mxid localpart which should be
returned instead. For example, if this method returns
`john.doe` as the value of `localpart` in the returned
dict, and that is already taken on the homeserver, this
method will be called again with the same parameters but
with failures=1. The method should then return a different
`localpart` value, such as `john.doe1`.
- Returns a dictionary with two keys:
- localpart: A required string, used to generate the Matrix ID.
- displayname: An optional string, the display name for the user.

View file

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
@ -35,15 +36,10 @@ from twisted.web.client import readBody
from synapse.config import ConfigError
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import (
JsonDict,
UserID,
contains_invalid_mxid_characters,
map_username_to_mxid_localpart,
)
from synapse.types import JsonDict, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING:
@ -869,73 +865,51 @@ class OidcHandler(BaseHandler):
# to be strings.
remote_user_id = str(remote_user_id)
# first of all, check if we already have a mapping for this user
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
self._auth_provider_id, remote_user_id,
# Older mapping providers don't accept the `failures` argument, so we
# try and detect support.
mapper_signature = inspect.signature(
self._user_mapping_provider.map_user_attributes
)
if previously_registered_user_id:
return previously_registered_user_id
supports_failures = "failures" in mapper_signature.parameters
# Otherwise, generate a new user.
try:
attributes = await self._user_mapping_provider.map_user_attributes(
userinfo, token
)
except Exception as e:
raise MappingException(
"Could not extract user attributes from OIDC response: " + str(e)
)
async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
"""
Call the mapping provider to map the OIDC userinfo and token to user attributes.
logger.debug(
"Retrieved user attributes from user mapping provider: %r", attributes
)
localpart = attributes["localpart"]
if not localpart:
raise MappingException(
"Error parsing OIDC response: OIDC mapping provider plugin "
"did not return a localpart value"
)
user_id = UserID(localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
if self._allow_existing_users:
if len(users) == 1:
registered_user_id = next(iter(users))
elif user_id in users:
registered_user_id = user_id
else:
raise MappingException(
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
user_id, list(users.keys())
)
)
This is backwards compatibility for abstraction for the SSO handler.
"""
if supports_failures:
attributes = await self._user_mapping_provider.map_user_attributes(
userinfo, token, failures
)
else:
# This mxid is taken
raise MappingException("mxid '{}' is already taken".format(user_id))
else:
# Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering.
if contains_invalid_mxid_characters(localpart):
raise MappingException("localpart is invalid: %s" % (localpart,))
# If the mapping provider does not support processing failures,
# do not continually generate the same Matrix ID since it will
# continue to already be in use. Note that the error raised is
# arbitrary and will get turned into a MappingException.
if failures:
raise RuntimeError(
"Mapping provider does not support de-duplicating Matrix IDs"
)
# It's the first time this user is logging in and the mapped mxid was
# not taken, register the user
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=attributes["display_name"],
user_agent_ips=[(user_agent, ip_address)],
)
attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
userinfo, token
)
await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id,
return UserAttributes(**attributes)
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
oidc_response_to_user_attributes,
self._allow_existing_users,
)
return registered_user_id
UserAttribute = TypedDict(
"UserAttribute", {"localpart": str, "display_name": Optional[str]}
UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
)
C = TypeVar("C")
@ -978,13 +952,15 @@ class OidcMappingProvider(Generic[C]):
raise NotImplementedError()
async def map_user_attributes(
self, userinfo: UserInfo, token: Token
) -> UserAttribute:
self, userinfo: UserInfo, token: Token, failures: int
) -> UserAttributeDict:
"""Map a `UserInfo` object into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
token: A dict with the tokens returned by the provider
failures: How many times a call to this function with this
UserInfo has resulted in a failure.
Returns:
A dict containing the ``localpart`` and (optionally) the ``display_name``
@ -1084,13 +1060,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
return userinfo[self._config.subject_claim]
async def map_user_attributes(
self, userinfo: UserInfo, token: Token
) -> UserAttribute:
self, userinfo: UserInfo, token: Token, failures: int
) -> UserAttributeDict:
localpart = self._config.localpart_template.render(user=userinfo).strip()
# Ensure only valid characters are included in the MXID.
localpart = map_username_to_mxid_localpart(localpart)
# Append suffix integer if last call to this function failed to produce
# a usable mxid.
localpart += str(failures) if failures else ""
display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
display_name = self._config.display_name_template.render(
@ -1100,7 +1080,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if display_name == "":
display_name = None
return UserAttribute(localpart=localpart, display_name=display_name)
return UserAttributeDict(localpart=localpart, display_name=display_name)
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]

View file

@ -25,13 +25,12 @@ from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
from synapse.types import (
UserID,
contains_invalid_mxid_characters,
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
@ -250,14 +249,26 @@ class SamlHandler(BaseHandler):
"Failed to extract remote user id from SAML response"
)
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
self._auth_provider_id, remote_user_id,
)
if previously_registered_user_id:
return previously_registered_user_id
async def saml_response_to_remapped_user_attributes(
failures: int,
) -> UserAttributes:
"""
Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
This is backwards compatibility for abstraction for the SSO handler.
"""
# Call the mapping provider.
result = self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, failures, client_redirect_url
)
# Remap some of the results.
return UserAttributes(
localpart=result.get("mxid_localpart"),
display_name=result.get("displayname"),
emails=result.get("emails"),
)
with (await self._mapping_lock.queue(self._auth_provider_id)):
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
@ -284,60 +295,14 @@ class SamlHandler(BaseHandler):
)
return registered_user_id
# Map saml response to user attributes using the configured mapping provider
for i in range(1000):
attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, i, client_redirect_url=client_redirect_url,
)
logger.debug(
"Retrieved SAML attributes from user mapping provider: %s "
"(attempt %d)",
attribute_dict,
i,
)
localpart = attribute_dict.get("mxid_localpart")
if not localpart:
raise MappingException(
"Error parsing SAML2 response: SAML mapping provider plugin "
"did not return a mxid_localpart value"
)
displayname = attribute_dict.get("displayname")
emails = attribute_dict.get("emails", [])
# Check if this mxid already exists
if not await self.store.get_users_by_id_case_insensitive(
UserID(localpart, self.server_name).to_string()
):
# This mxid is free
break
else:
# Unable to generate a username in 1000 iterations
# Break and return error to the user
raise MappingException(
"Unable to generate a Matrix ID from the SAML response"
)
# Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering.
if contains_invalid_mxid_characters(localpart):
raise MappingException("localpart is invalid: %s" % (localpart,))
logger.debug("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user(
localpart=localpart,
default_display_name=displayname,
bind_emails=emails,
user_agent_ips=[(user_agent, ip_address)],
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
saml_response_to_remapped_user_attributes,
)
await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
@ -451,11 +416,11 @@ class DefaultSamlMappingProvider:
)
# Use the configured mapper for this mxid_source
base_mxid_localpart = self._mxid_mapper(mxid_source)
localpart = self._mxid_mapper(mxid_source)
# Append suffix integer if last call to this function failed to produce
# a usable mxid
localpart = base_mxid_localpart + (str(failures) if failures else "")
# a usable mxid.
localpart += str(failures) if failures else ""
# Retrieve the display name from the saml response
# If displayname is None, the mxid_localpart will be used instead

View file

@ -13,10 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
import attr
from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html
from synapse.types import UserID, contains_invalid_mxid_characters
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -29,9 +32,20 @@ class MappingException(Exception):
"""
@attr.s
class UserAttributes:
localpart = attr.ib(type=str)
display_name = attr.ib(type=Optional[str], default=None)
emails = attr.ib(type=List[str], default=attr.Factory(list))
class SsoHandler(BaseHandler):
# The number of attempts to ask the mapping provider for when generating an MXID.
_MAP_USERNAME_RETRIES = 1000
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._registration_handler = hs.get_registration_handler()
self._error_template = hs.config.sso_error_template
def render_error(
@ -94,3 +108,142 @@ class SsoHandler(BaseHandler):
# No match.
return None
async def get_mxid_from_sso(
self,
auth_provider_id: str,
remote_user_id: str,
user_agent: str,
ip_address: str,
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
allow_existing_users: bool = False,
) -> str:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
This first checks if the SSO ID has previously been linked to a matrix ID,
if it has that matrix ID is returned regardless of the current mapping
logic.
The mapping function is called (potentially multiple times) to generate
a localpart for the user.
If an unused localpart is generated, the user is registered from the
given user-agent and IP address and the SSO ID is linked to this matrix
ID for subsequent calls.
If allow_existing_users is true the mapping function is only called once
and results in:
1. The use of a previously registered matrix ID. In this case, the
SSO ID is linked to the matrix ID. (Note it is possible that
other SSO IDs are linked to the same matrix ID.)
2. An unused localpart, in which case the user is registered (as
discussed above).
3. An error if the generated localpart matches multiple pre-existing
matrix IDs. Generally this should not happen.
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
remote_user_id: The unique identifier from the SSO provider.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
sso_to_matrix_id_mapper: A callable to generate the user attributes.
The only parameter is an integer which represents the amount of
times the returned mxid localpart mapping has failed.
allow_existing_users: True if the localpart returned from the
mapping provider can be linked to an existing matrix ID.
Returns:
The user ID associated with the SSO response.
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
# first of all, check if we already have a mapping for this user
previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
auth_provider_id, remote_user_id,
)
if previously_registered_user_id:
return previously_registered_user_id
# Otherwise, generate a new user.
for i in range(self._MAP_USERNAME_RETRIES):
try:
attributes = await sso_to_matrix_id_mapper(i)
except Exception as e:
raise MappingException(
"Could not extract user attributes from SSO response: " + str(e)
)
logger.debug(
"Retrieved user attributes from user mapping provider: %r (attempt %d)",
attributes,
i,
)
if not attributes.localpart:
raise MappingException(
"Error parsing SSO response: SSO mapping provider plugin "
"did not return a localpart value"
)
# Check if this mxid already exists
user_id = UserID(attributes.localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
# Note, if allow_existing_users is true then the loop is guaranteed
# to end on the first iteration: either by matching an existing user,
# raising an error, or registering a new user. See the docstring for
# more in-depth an explanation.
if users and allow_existing_users:
# If an existing matrix ID is returned, then use it.
if len(users) == 1:
previously_registered_user_id = next(iter(users))
elif user_id in users:
previously_registered_user_id = user_id
else:
# Do not attempt to continue generating Matrix IDs.
raise MappingException(
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
user_id, users
)
)
# Future logins should also match this user ID.
await self.store.record_user_external_id(
auth_provider_id, remote_user_id, previously_registered_user_id
)
return previously_registered_user_id
elif not users:
# This mxid is free
break
else:
# Unable to generate a username in 1000 iterations
# Break and return error to the user
raise MappingException(
"Unable to generate a Matrix ID from the SSO response"
)
# Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering.
if contains_invalid_mxid_characters(attributes.localpart):
raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
logger.debug("Mapped SSO user to local part %s", attributes.localpart)
registered_user_id = await self._registration_handler.register_user(
localpart=attributes.localpart,
default_display_name=attributes.display_name,
bind_emails=attributes.emails,
user_agent_ips=[(user_agent, ip_address)],
)
await self.store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id

View file

@ -89,6 +89,14 @@ class TestMappingProviderExtra(TestMappingProvider):
return {"phone": userinfo["phone"]}
class TestMappingProviderFailures(TestMappingProvider):
async def map_user_attributes(self, userinfo, token, failures):
return {
"localpart": userinfo["username"] + (str(failures) if failures else ""),
"display_name": None,
}
def simple_async_mock(return_value=None, raises=None):
# AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args, **kwargs):
@ -152,6 +160,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.render_error = Mock(return_value=None)
self.handler._sso_handler.render_error = self.render_error
# Reduce the number of attempts when generating MXIDs.
self.handler._sso_handler._MAP_USERNAME_RETRIES = 3
return hs
def metadata_edit(self, values):
@ -693,7 +704,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
),
MappingException,
)
self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken")
self.assertEqual(
str(e.value),
"Could not extract user attributes from SSO response: Mapping provider does not support de-duplicating Matrix IDs",
)
@override_config({"oidc_config": {"allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self):
@ -703,6 +717,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(
store.register_user(user_id=user.to_string(), password_hash=None)
)
# Map a user via SSO.
userinfo = {
"sub": "test",
"username": "test_user",
@ -715,6 +731,23 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.assertEqual(mxid, "@test_user:test")
# Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID,
# in this case we'll just use the same username. A more realistic example
# would be subs which are email addresses, and mapping from the localpart
# of the email, e.g. bob@foo.com and bob@bar.com -> @bob:test.)
userinfo = {
"sub": "test1",
"username": "test_user",
}
token = {}
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
)
self.assertEqual(mxid, "@test_user:test")
# Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test")
self.get_success(
@ -762,6 +795,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "föö",
}
token = {}
e = self.get_failure(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
@ -769,3 +803,55 @@ class OidcHandlerTestCase(HomeserverTestCase):
MappingException,
)
self.assertEqual(str(e.value), "localpart is invalid: föö")
@override_config(
{
"oidc_config": {
"user_mapping_provider": {
"module": __name__ + ".TestMappingProviderFailures"
}
}
}
)
def test_map_userinfo_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
)
userinfo = {
"sub": "test",
"username": "test_user",
}
token = {}
mxid = self.get_success(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
)
)
# test_user is already taken, so test_user1 gets registered instead.
self.assertEqual(mxid, "@test_user1:test")
# Register all of the potential users for a particular username.
self.get_success(
store.register_user(user_id="@tester:test", password_hash=None)
)
for i in range(1, 3):
self.get_success(
store.register_user(user_id="@tester%d:test" % i, password_hash=None)
)
# Now attempt to map to a username, this will fail since all potential usernames are taken.
userinfo = {
"sub": "tester",
"username": "tester",
}
e = self.get_failure(
self.handler._map_userinfo_to_user(
userinfo, token, "user-agent", "10.10.10.10"
),
MappingException,
)
self.assertEqual(
str(e.value), "Unable to generate a Matrix ID from the SSO response"
)