mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 13:13:50 +01:00
Support trying multiple localparts for OpenID Connect. (#8801)
Abstracts the SAML and OpenID Connect code which attempts to regenerate the localpart of a matrix ID if it is already in use.
This commit is contained in:
parent
f38676d161
commit
4fd222ad70
6 changed files with 331 additions and 137 deletions
1
changelog.d/8801.feature
Normal file
1
changelog.d/8801.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add support for re-trying generation of a localpart for OpenID Connect mapping providers.
|
|
@ -63,13 +63,22 @@ A custom mapping provider must specify the following methods:
|
||||||
information from.
|
information from.
|
||||||
- This method must return a string, which is the unique identifier for the
|
- This method must return a string, which is the unique identifier for the
|
||||||
user. Commonly the ``sub`` claim of the response.
|
user. Commonly the ``sub`` claim of the response.
|
||||||
* `map_user_attributes(self, userinfo, token)`
|
* `map_user_attributes(self, userinfo, token, failures)`
|
||||||
- This method must be async.
|
- This method must be async.
|
||||||
- Arguments:
|
- Arguments:
|
||||||
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
|
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
|
||||||
information from.
|
information from.
|
||||||
- `token` - A dictionary which includes information necessary to make
|
- `token` - A dictionary which includes information necessary to make
|
||||||
further requests to the OpenID provider.
|
further requests to the OpenID provider.
|
||||||
|
- `failures` - An `int` that represents the amount of times the returned
|
||||||
|
mxid localpart mapping has failed. This should be used
|
||||||
|
to create a deduplicated mxid localpart which should be
|
||||||
|
returned instead. For example, if this method returns
|
||||||
|
`john.doe` as the value of `localpart` in the returned
|
||||||
|
dict, and that is already taken on the homeserver, this
|
||||||
|
method will be called again with the same parameters but
|
||||||
|
with failures=1. The method should then return a different
|
||||||
|
`localpart` value, such as `john.doe1`.
|
||||||
- Returns a dictionary with two keys:
|
- Returns a dictionary with two keys:
|
||||||
- localpart: A required string, used to generate the Matrix ID.
|
- localpart: A required string, used to generate the Matrix ID.
|
||||||
- displayname: An optional string, the display name for the user.
|
- displayname: An optional string, the display name for the user.
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
|
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
@ -35,15 +36,10 @@ from twisted.web.client import readBody
|
||||||
|
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
from synapse.handlers._base import BaseHandler
|
from synapse.handlers._base import BaseHandler
|
||||||
from synapse.handlers.sso import MappingException
|
from synapse.handlers.sso import MappingException, UserAttributes
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.types import (
|
from synapse.types import JsonDict, map_username_to_mxid_localpart
|
||||||
JsonDict,
|
|
||||||
UserID,
|
|
||||||
contains_invalid_mxid_characters,
|
|
||||||
map_username_to_mxid_localpart,
|
|
||||||
)
|
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -869,73 +865,51 @@ class OidcHandler(BaseHandler):
|
||||||
# to be strings.
|
# to be strings.
|
||||||
remote_user_id = str(remote_user_id)
|
remote_user_id = str(remote_user_id)
|
||||||
|
|
||||||
# first of all, check if we already have a mapping for this user
|
# Older mapping providers don't accept the `failures` argument, so we
|
||||||
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
|
# try and detect support.
|
||||||
self._auth_provider_id, remote_user_id,
|
mapper_signature = inspect.signature(
|
||||||
|
self._user_mapping_provider.map_user_attributes
|
||||||
)
|
)
|
||||||
if previously_registered_user_id:
|
supports_failures = "failures" in mapper_signature.parameters
|
||||||
return previously_registered_user_id
|
|
||||||
|
|
||||||
# Otherwise, generate a new user.
|
async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
|
||||||
try:
|
"""
|
||||||
|
Call the mapping provider to map the OIDC userinfo and token to user attributes.
|
||||||
|
|
||||||
|
This is backwards compatibility for abstraction for the SSO handler.
|
||||||
|
"""
|
||||||
|
if supports_failures:
|
||||||
attributes = await self._user_mapping_provider.map_user_attributes(
|
attributes = await self._user_mapping_provider.map_user_attributes(
|
||||||
|
userinfo, token, failures
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# If the mapping provider does not support processing failures,
|
||||||
|
# do not continually generate the same Matrix ID since it will
|
||||||
|
# continue to already be in use. Note that the error raised is
|
||||||
|
# arbitrary and will get turned into a MappingException.
|
||||||
|
if failures:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Mapping provider does not support de-duplicating Matrix IDs"
|
||||||
|
)
|
||||||
|
|
||||||
|
attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
|
||||||
userinfo, token
|
userinfo, token
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
raise MappingException(
|
return UserAttributes(**attributes)
|
||||||
"Could not extract user attributes from OIDC response: " + str(e)
|
|
||||||
|
return await self._sso_handler.get_mxid_from_sso(
|
||||||
|
self._auth_provider_id,
|
||||||
|
remote_user_id,
|
||||||
|
user_agent,
|
||||||
|
ip_address,
|
||||||
|
oidc_response_to_user_attributes,
|
||||||
|
self._allow_existing_users,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"Retrieved user attributes from user mapping provider: %r", attributes
|
|
||||||
)
|
|
||||||
|
|
||||||
localpart = attributes["localpart"]
|
UserAttributeDict = TypedDict(
|
||||||
if not localpart:
|
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
|
||||||
raise MappingException(
|
|
||||||
"Error parsing OIDC response: OIDC mapping provider plugin "
|
|
||||||
"did not return a localpart value"
|
|
||||||
)
|
|
||||||
|
|
||||||
user_id = UserID(localpart, self.server_name).to_string()
|
|
||||||
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
|
||||||
if users:
|
|
||||||
if self._allow_existing_users:
|
|
||||||
if len(users) == 1:
|
|
||||||
registered_user_id = next(iter(users))
|
|
||||||
elif user_id in users:
|
|
||||||
registered_user_id = user_id
|
|
||||||
else:
|
|
||||||
raise MappingException(
|
|
||||||
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
|
|
||||||
user_id, list(users.keys())
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# This mxid is taken
|
|
||||||
raise MappingException("mxid '{}' is already taken".format(user_id))
|
|
||||||
else:
|
|
||||||
# Since the localpart is provided via a potentially untrusted module,
|
|
||||||
# ensure the MXID is valid before registering.
|
|
||||||
if contains_invalid_mxid_characters(localpart):
|
|
||||||
raise MappingException("localpart is invalid: %s" % (localpart,))
|
|
||||||
|
|
||||||
# It's the first time this user is logging in and the mapped mxid was
|
|
||||||
# not taken, register the user
|
|
||||||
registered_user_id = await self._registration_handler.register_user(
|
|
||||||
localpart=localpart,
|
|
||||||
default_display_name=attributes["display_name"],
|
|
||||||
user_agent_ips=[(user_agent, ip_address)],
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.store.record_user_external_id(
|
|
||||||
self._auth_provider_id, remote_user_id, registered_user_id,
|
|
||||||
)
|
|
||||||
return registered_user_id
|
|
||||||
|
|
||||||
|
|
||||||
UserAttribute = TypedDict(
|
|
||||||
"UserAttribute", {"localpart": str, "display_name": Optional[str]}
|
|
||||||
)
|
)
|
||||||
C = TypeVar("C")
|
C = TypeVar("C")
|
||||||
|
|
||||||
|
@ -978,13 +952,15 @@ class OidcMappingProvider(Generic[C]):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def map_user_attributes(
|
async def map_user_attributes(
|
||||||
self, userinfo: UserInfo, token: Token
|
self, userinfo: UserInfo, token: Token, failures: int
|
||||||
) -> UserAttribute:
|
) -> UserAttributeDict:
|
||||||
"""Map a `UserInfo` object into user attributes.
|
"""Map a `UserInfo` object into user attributes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
userinfo: An object representing the user given by the OIDC provider
|
userinfo: An object representing the user given by the OIDC provider
|
||||||
token: A dict with the tokens returned by the provider
|
token: A dict with the tokens returned by the provider
|
||||||
|
failures: How many times a call to this function with this
|
||||||
|
UserInfo has resulted in a failure.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dict containing the ``localpart`` and (optionally) the ``display_name``
|
A dict containing the ``localpart`` and (optionally) the ``display_name``
|
||||||
|
@ -1084,13 +1060,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
||||||
return userinfo[self._config.subject_claim]
|
return userinfo[self._config.subject_claim]
|
||||||
|
|
||||||
async def map_user_attributes(
|
async def map_user_attributes(
|
||||||
self, userinfo: UserInfo, token: Token
|
self, userinfo: UserInfo, token: Token, failures: int
|
||||||
) -> UserAttribute:
|
) -> UserAttributeDict:
|
||||||
localpart = self._config.localpart_template.render(user=userinfo).strip()
|
localpart = self._config.localpart_template.render(user=userinfo).strip()
|
||||||
|
|
||||||
# Ensure only valid characters are included in the MXID.
|
# Ensure only valid characters are included in the MXID.
|
||||||
localpart = map_username_to_mxid_localpart(localpart)
|
localpart = map_username_to_mxid_localpart(localpart)
|
||||||
|
|
||||||
|
# Append suffix integer if last call to this function failed to produce
|
||||||
|
# a usable mxid.
|
||||||
|
localpart += str(failures) if failures else ""
|
||||||
|
|
||||||
display_name = None # type: Optional[str]
|
display_name = None # type: Optional[str]
|
||||||
if self._config.display_name_template is not None:
|
if self._config.display_name_template is not None:
|
||||||
display_name = self._config.display_name_template.render(
|
display_name = self._config.display_name_template.render(
|
||||||
|
@ -1100,7 +1080,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
||||||
if display_name == "":
|
if display_name == "":
|
||||||
display_name = None
|
display_name = None
|
||||||
|
|
||||||
return UserAttribute(localpart=localpart, display_name=display_name)
|
return UserAttributeDict(localpart=localpart, display_name=display_name)
|
||||||
|
|
||||||
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
|
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
|
||||||
extras = {} # type: Dict[str, str]
|
extras = {} # type: Dict[str, str]
|
||||||
|
|
|
@ -25,13 +25,12 @@ from synapse.api.errors import SynapseError
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
from synapse.config.saml2_config import SamlAttributeRequirement
|
from synapse.config.saml2_config import SamlAttributeRequirement
|
||||||
from synapse.handlers._base import BaseHandler
|
from synapse.handlers._base import BaseHandler
|
||||||
from synapse.handlers.sso import MappingException
|
from synapse.handlers.sso import MappingException, UserAttributes
|
||||||
from synapse.http.servlet import parse_string
|
from synapse.http.servlet import parse_string
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.module_api import ModuleApi
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
UserID,
|
UserID,
|
||||||
contains_invalid_mxid_characters,
|
|
||||||
map_username_to_mxid_localpart,
|
map_username_to_mxid_localpart,
|
||||||
mxid_localpart_allowed_characters,
|
mxid_localpart_allowed_characters,
|
||||||
)
|
)
|
||||||
|
@ -250,14 +249,26 @@ class SamlHandler(BaseHandler):
|
||||||
"Failed to extract remote user id from SAML response"
|
"Failed to extract remote user id from SAML response"
|
||||||
)
|
)
|
||||||
|
|
||||||
with (await self._mapping_lock.queue(self._auth_provider_id)):
|
async def saml_response_to_remapped_user_attributes(
|
||||||
# first of all, check if we already have a mapping for this user
|
failures: int,
|
||||||
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
|
) -> UserAttributes:
|
||||||
self._auth_provider_id, remote_user_id,
|
"""
|
||||||
)
|
Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
|
||||||
if previously_registered_user_id:
|
|
||||||
return previously_registered_user_id
|
|
||||||
|
|
||||||
|
This is backwards compatibility for abstraction for the SSO handler.
|
||||||
|
"""
|
||||||
|
# Call the mapping provider.
|
||||||
|
result = self._user_mapping_provider.saml_response_to_user_attributes(
|
||||||
|
saml2_auth, failures, client_redirect_url
|
||||||
|
)
|
||||||
|
# Remap some of the results.
|
||||||
|
return UserAttributes(
|
||||||
|
localpart=result.get("mxid_localpart"),
|
||||||
|
display_name=result.get("displayname"),
|
||||||
|
emails=result.get("emails"),
|
||||||
|
)
|
||||||
|
|
||||||
|
with (await self._mapping_lock.queue(self._auth_provider_id)):
|
||||||
# backwards-compatibility hack: see if there is an existing user with a
|
# backwards-compatibility hack: see if there is an existing user with a
|
||||||
# suitable mapping from the uid
|
# suitable mapping from the uid
|
||||||
if (
|
if (
|
||||||
|
@ -284,60 +295,14 @@ class SamlHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
return registered_user_id
|
return registered_user_id
|
||||||
|
|
||||||
# Map saml response to user attributes using the configured mapping provider
|
return await self._sso_handler.get_mxid_from_sso(
|
||||||
for i in range(1000):
|
self._auth_provider_id,
|
||||||
attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
|
remote_user_id,
|
||||||
saml2_auth, i, client_redirect_url=client_redirect_url,
|
user_agent,
|
||||||
|
ip_address,
|
||||||
|
saml_response_to_remapped_user_attributes,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"Retrieved SAML attributes from user mapping provider: %s "
|
|
||||||
"(attempt %d)",
|
|
||||||
attribute_dict,
|
|
||||||
i,
|
|
||||||
)
|
|
||||||
|
|
||||||
localpart = attribute_dict.get("mxid_localpart")
|
|
||||||
if not localpart:
|
|
||||||
raise MappingException(
|
|
||||||
"Error parsing SAML2 response: SAML mapping provider plugin "
|
|
||||||
"did not return a mxid_localpart value"
|
|
||||||
)
|
|
||||||
|
|
||||||
displayname = attribute_dict.get("displayname")
|
|
||||||
emails = attribute_dict.get("emails", [])
|
|
||||||
|
|
||||||
# Check if this mxid already exists
|
|
||||||
if not await self.store.get_users_by_id_case_insensitive(
|
|
||||||
UserID(localpart, self.server_name).to_string()
|
|
||||||
):
|
|
||||||
# This mxid is free
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# Unable to generate a username in 1000 iterations
|
|
||||||
# Break and return error to the user
|
|
||||||
raise MappingException(
|
|
||||||
"Unable to generate a Matrix ID from the SAML response"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Since the localpart is provided via a potentially untrusted module,
|
|
||||||
# ensure the MXID is valid before registering.
|
|
||||||
if contains_invalid_mxid_characters(localpart):
|
|
||||||
raise MappingException("localpart is invalid: %s" % (localpart,))
|
|
||||||
|
|
||||||
logger.debug("Mapped SAML user to local part %s", localpart)
|
|
||||||
registered_user_id = await self._registration_handler.register_user(
|
|
||||||
localpart=localpart,
|
|
||||||
default_display_name=displayname,
|
|
||||||
bind_emails=emails,
|
|
||||||
user_agent_ips=[(user_agent, ip_address)],
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.store.record_user_external_id(
|
|
||||||
self._auth_provider_id, remote_user_id, registered_user_id
|
|
||||||
)
|
|
||||||
return registered_user_id
|
|
||||||
|
|
||||||
def expire_sessions(self):
|
def expire_sessions(self):
|
||||||
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
|
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
|
||||||
to_expire = set()
|
to_expire = set()
|
||||||
|
@ -451,11 +416,11 @@ class DefaultSamlMappingProvider:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the configured mapper for this mxid_source
|
# Use the configured mapper for this mxid_source
|
||||||
base_mxid_localpart = self._mxid_mapper(mxid_source)
|
localpart = self._mxid_mapper(mxid_source)
|
||||||
|
|
||||||
# Append suffix integer if last call to this function failed to produce
|
# Append suffix integer if last call to this function failed to produce
|
||||||
# a usable mxid
|
# a usable mxid.
|
||||||
localpart = base_mxid_localpart + (str(failures) if failures else "")
|
localpart += str(failures) if failures else ""
|
||||||
|
|
||||||
# Retrieve the display name from the saml response
|
# Retrieve the display name from the saml response
|
||||||
# If displayname is None, the mxid_localpart will be used instead
|
# If displayname is None, the mxid_localpart will be used instead
|
||||||
|
|
|
@ -13,10 +13,13 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
from synapse.handlers._base import BaseHandler
|
from synapse.handlers._base import BaseHandler
|
||||||
from synapse.http.server import respond_with_html
|
from synapse.http.server import respond_with_html
|
||||||
|
from synapse.types import UserID, contains_invalid_mxid_characters
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -29,9 +32,20 @@ class MappingException(Exception):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
class UserAttributes:
|
||||||
|
localpart = attr.ib(type=str)
|
||||||
|
display_name = attr.ib(type=Optional[str], default=None)
|
||||||
|
emails = attr.ib(type=List[str], default=attr.Factory(list))
|
||||||
|
|
||||||
|
|
||||||
class SsoHandler(BaseHandler):
|
class SsoHandler(BaseHandler):
|
||||||
|
# The number of attempts to ask the mapping provider for when generating an MXID.
|
||||||
|
_MAP_USERNAME_RETRIES = 1000
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
self._registration_handler = hs.get_registration_handler()
|
||||||
self._error_template = hs.config.sso_error_template
|
self._error_template = hs.config.sso_error_template
|
||||||
|
|
||||||
def render_error(
|
def render_error(
|
||||||
|
@ -94,3 +108,142 @@ class SsoHandler(BaseHandler):
|
||||||
|
|
||||||
# No match.
|
# No match.
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_mxid_from_sso(
|
||||||
|
self,
|
||||||
|
auth_provider_id: str,
|
||||||
|
remote_user_id: str,
|
||||||
|
user_agent: str,
|
||||||
|
ip_address: str,
|
||||||
|
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
|
||||||
|
allow_existing_users: bool = False,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Given an SSO ID, retrieve the user ID for it and possibly register the user.
|
||||||
|
|
||||||
|
This first checks if the SSO ID has previously been linked to a matrix ID,
|
||||||
|
if it has that matrix ID is returned regardless of the current mapping
|
||||||
|
logic.
|
||||||
|
|
||||||
|
The mapping function is called (potentially multiple times) to generate
|
||||||
|
a localpart for the user.
|
||||||
|
|
||||||
|
If an unused localpart is generated, the user is registered from the
|
||||||
|
given user-agent and IP address and the SSO ID is linked to this matrix
|
||||||
|
ID for subsequent calls.
|
||||||
|
|
||||||
|
If allow_existing_users is true the mapping function is only called once
|
||||||
|
and results in:
|
||||||
|
|
||||||
|
1. The use of a previously registered matrix ID. In this case, the
|
||||||
|
SSO ID is linked to the matrix ID. (Note it is possible that
|
||||||
|
other SSO IDs are linked to the same matrix ID.)
|
||||||
|
2. An unused localpart, in which case the user is registered (as
|
||||||
|
discussed above).
|
||||||
|
3. An error if the generated localpart matches multiple pre-existing
|
||||||
|
matrix IDs. Generally this should not happen.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_provider_id: A unique identifier for this SSO provider, e.g.
|
||||||
|
"oidc" or "saml".
|
||||||
|
remote_user_id: The unique identifier from the SSO provider.
|
||||||
|
user_agent: The user agent of the client making the request.
|
||||||
|
ip_address: The IP address of the client making the request.
|
||||||
|
sso_to_matrix_id_mapper: A callable to generate the user attributes.
|
||||||
|
The only parameter is an integer which represents the amount of
|
||||||
|
times the returned mxid localpart mapping has failed.
|
||||||
|
allow_existing_users: True if the localpart returned from the
|
||||||
|
mapping provider can be linked to an existing matrix ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The user ID associated with the SSO response.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MappingException if there was a problem mapping the response to a user.
|
||||||
|
RedirectException: some mapping providers may raise this if they need
|
||||||
|
to redirect to an interstitial page.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# first of all, check if we already have a mapping for this user
|
||||||
|
previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
|
||||||
|
auth_provider_id, remote_user_id,
|
||||||
|
)
|
||||||
|
if previously_registered_user_id:
|
||||||
|
return previously_registered_user_id
|
||||||
|
|
||||||
|
# Otherwise, generate a new user.
|
||||||
|
for i in range(self._MAP_USERNAME_RETRIES):
|
||||||
|
try:
|
||||||
|
attributes = await sso_to_matrix_id_mapper(i)
|
||||||
|
except Exception as e:
|
||||||
|
raise MappingException(
|
||||||
|
"Could not extract user attributes from SSO response: " + str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Retrieved user attributes from user mapping provider: %r (attempt %d)",
|
||||||
|
attributes,
|
||||||
|
i,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not attributes.localpart:
|
||||||
|
raise MappingException(
|
||||||
|
"Error parsing SSO response: SSO mapping provider plugin "
|
||||||
|
"did not return a localpart value"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if this mxid already exists
|
||||||
|
user_id = UserID(attributes.localpart, self.server_name).to_string()
|
||||||
|
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||||
|
# Note, if allow_existing_users is true then the loop is guaranteed
|
||||||
|
# to end on the first iteration: either by matching an existing user,
|
||||||
|
# raising an error, or registering a new user. See the docstring for
|
||||||
|
# more in-depth an explanation.
|
||||||
|
if users and allow_existing_users:
|
||||||
|
# If an existing matrix ID is returned, then use it.
|
||||||
|
if len(users) == 1:
|
||||||
|
previously_registered_user_id = next(iter(users))
|
||||||
|
elif user_id in users:
|
||||||
|
previously_registered_user_id = user_id
|
||||||
|
else:
|
||||||
|
# Do not attempt to continue generating Matrix IDs.
|
||||||
|
raise MappingException(
|
||||||
|
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
|
||||||
|
user_id, users
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Future logins should also match this user ID.
|
||||||
|
await self.store.record_user_external_id(
|
||||||
|
auth_provider_id, remote_user_id, previously_registered_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return previously_registered_user_id
|
||||||
|
|
||||||
|
elif not users:
|
||||||
|
# This mxid is free
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Unable to generate a username in 1000 iterations
|
||||||
|
# Break and return error to the user
|
||||||
|
raise MappingException(
|
||||||
|
"Unable to generate a Matrix ID from the SSO response"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Since the localpart is provided via a potentially untrusted module,
|
||||||
|
# ensure the MXID is valid before registering.
|
||||||
|
if contains_invalid_mxid_characters(attributes.localpart):
|
||||||
|
raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
|
||||||
|
|
||||||
|
logger.debug("Mapped SSO user to local part %s", attributes.localpart)
|
||||||
|
registered_user_id = await self._registration_handler.register_user(
|
||||||
|
localpart=attributes.localpart,
|
||||||
|
default_display_name=attributes.display_name,
|
||||||
|
bind_emails=attributes.emails,
|
||||||
|
user_agent_ips=[(user_agent, ip_address)],
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.store.record_user_external_id(
|
||||||
|
auth_provider_id, remote_user_id, registered_user_id
|
||||||
|
)
|
||||||
|
return registered_user_id
|
||||||
|
|
|
@ -89,6 +89,14 @@ class TestMappingProviderExtra(TestMappingProvider):
|
||||||
return {"phone": userinfo["phone"]}
|
return {"phone": userinfo["phone"]}
|
||||||
|
|
||||||
|
|
||||||
|
class TestMappingProviderFailures(TestMappingProvider):
|
||||||
|
async def map_user_attributes(self, userinfo, token, failures):
|
||||||
|
return {
|
||||||
|
"localpart": userinfo["username"] + (str(failures) if failures else ""),
|
||||||
|
"display_name": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def simple_async_mock(return_value=None, raises=None):
|
def simple_async_mock(return_value=None, raises=None):
|
||||||
# AsyncMock is not available in python3.5, this mimics part of its behaviour
|
# AsyncMock is not available in python3.5, this mimics part of its behaviour
|
||||||
async def cb(*args, **kwargs):
|
async def cb(*args, **kwargs):
|
||||||
|
@ -152,6 +160,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.render_error = Mock(return_value=None)
|
self.render_error = Mock(return_value=None)
|
||||||
self.handler._sso_handler.render_error = self.render_error
|
self.handler._sso_handler.render_error = self.render_error
|
||||||
|
|
||||||
|
# Reduce the number of attempts when generating MXIDs.
|
||||||
|
self.handler._sso_handler._MAP_USERNAME_RETRIES = 3
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
def metadata_edit(self, values):
|
def metadata_edit(self, values):
|
||||||
|
@ -693,7 +704,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
),
|
),
|
||||||
MappingException,
|
MappingException,
|
||||||
)
|
)
|
||||||
self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken")
|
self.assertEqual(
|
||||||
|
str(e.value),
|
||||||
|
"Could not extract user attributes from SSO response: Mapping provider does not support de-duplicating Matrix IDs",
|
||||||
|
)
|
||||||
|
|
||||||
@override_config({"oidc_config": {"allow_existing_users": True}})
|
@override_config({"oidc_config": {"allow_existing_users": True}})
|
||||||
def test_map_userinfo_to_existing_user(self):
|
def test_map_userinfo_to_existing_user(self):
|
||||||
|
@ -703,6 +717,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
self.get_success(
|
self.get_success(
|
||||||
store.register_user(user_id=user.to_string(), password_hash=None)
|
store.register_user(user_id=user.to_string(), password_hash=None)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Map a user via SSO.
|
||||||
userinfo = {
|
userinfo = {
|
||||||
"sub": "test",
|
"sub": "test",
|
||||||
"username": "test_user",
|
"username": "test_user",
|
||||||
|
@ -715,6 +731,23 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(mxid, "@test_user:test")
|
self.assertEqual(mxid, "@test_user:test")
|
||||||
|
|
||||||
|
# Note that a second SSO user can be mapped to the same Matrix ID. (This
|
||||||
|
# requires a unique sub, but something that maps to the same matrix ID,
|
||||||
|
# in this case we'll just use the same username. A more realistic example
|
||||||
|
# would be subs which are email addresses, and mapping from the localpart
|
||||||
|
# of the email, e.g. bob@foo.com and bob@bar.com -> @bob:test.)
|
||||||
|
userinfo = {
|
||||||
|
"sub": "test1",
|
||||||
|
"username": "test_user",
|
||||||
|
}
|
||||||
|
token = {}
|
||||||
|
mxid = self.get_success(
|
||||||
|
self.handler._map_userinfo_to_user(
|
||||||
|
userinfo, token, "user-agent", "10.10.10.10"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(mxid, "@test_user:test")
|
||||||
|
|
||||||
# Register some non-exact matching cases.
|
# Register some non-exact matching cases.
|
||||||
user2 = UserID.from_string("@TEST_user_2:test")
|
user2 = UserID.from_string("@TEST_user_2:test")
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -762,6 +795,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
"username": "föö",
|
"username": "föö",
|
||||||
}
|
}
|
||||||
token = {}
|
token = {}
|
||||||
|
|
||||||
e = self.get_failure(
|
e = self.get_failure(
|
||||||
self.handler._map_userinfo_to_user(
|
self.handler._map_userinfo_to_user(
|
||||||
userinfo, token, "user-agent", "10.10.10.10"
|
userinfo, token, "user-agent", "10.10.10.10"
|
||||||
|
@ -769,3 +803,55 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
MappingException,
|
MappingException,
|
||||||
)
|
)
|
||||||
self.assertEqual(str(e.value), "localpart is invalid: föö")
|
self.assertEqual(str(e.value), "localpart is invalid: föö")
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"oidc_config": {
|
||||||
|
"user_mapping_provider": {
|
||||||
|
"module": __name__ + ".TestMappingProviderFailures"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_map_userinfo_to_user_retries(self):
|
||||||
|
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
||||||
|
store = self.hs.get_datastore()
|
||||||
|
self.get_success(
|
||||||
|
store.register_user(user_id="@test_user:test", password_hash=None)
|
||||||
|
)
|
||||||
|
userinfo = {
|
||||||
|
"sub": "test",
|
||||||
|
"username": "test_user",
|
||||||
|
}
|
||||||
|
token = {}
|
||||||
|
mxid = self.get_success(
|
||||||
|
self.handler._map_userinfo_to_user(
|
||||||
|
userinfo, token, "user-agent", "10.10.10.10"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# test_user is already taken, so test_user1 gets registered instead.
|
||||||
|
self.assertEqual(mxid, "@test_user1:test")
|
||||||
|
|
||||||
|
# Register all of the potential users for a particular username.
|
||||||
|
self.get_success(
|
||||||
|
store.register_user(user_id="@tester:test", password_hash=None)
|
||||||
|
)
|
||||||
|
for i in range(1, 3):
|
||||||
|
self.get_success(
|
||||||
|
store.register_user(user_id="@tester%d:test" % i, password_hash=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now attempt to map to a username, this will fail since all potential usernames are taken.
|
||||||
|
userinfo = {
|
||||||
|
"sub": "tester",
|
||||||
|
"username": "tester",
|
||||||
|
}
|
||||||
|
e = self.get_failure(
|
||||||
|
self.handler._map_userinfo_to_user(
|
||||||
|
userinfo, token, "user-agent", "10.10.10.10"
|
||||||
|
),
|
||||||
|
MappingException,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
str(e.value), "Unable to generate a Matrix ID from the SSO response"
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in a new issue