mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-06 01:14:07 +01:00
Refactor SsoHandler.get_mxid_from_sso
(#8900)
* Factor out _call_attribute_mapper and _register_mapped_user This is mostly an attempt to simplify `get_mxid_from_sso`. * Move mapping_lock down into SsoHandler.
This commit is contained in:
parent
1821f7cc26
commit
c64002e1c1
3 changed files with 51 additions and 28 deletions
1
changelog.d/8900.feature
Normal file
1
changelog.d/8900.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add support for allowing users to pick their own user ID during a single-sign-on login.
|
|
@ -34,7 +34,6 @@ from synapse.types import (
|
||||||
map_username_to_mxid_localpart,
|
map_username_to_mxid_localpart,
|
||||||
mxid_localpart_allowed_characters,
|
mxid_localpart_allowed_characters,
|
||||||
)
|
)
|
||||||
from synapse.util.async_helpers import Linearizer
|
|
||||||
from synapse.util.iterutils import chunk_seq
|
from synapse.util.iterutils import chunk_seq
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -81,9 +80,6 @@ class SamlHandler(BaseHandler):
|
||||||
# a map from saml session id to Saml2SessionData object
|
# a map from saml session id to Saml2SessionData object
|
||||||
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
|
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
|
||||||
|
|
||||||
# a lock on the mappings
|
|
||||||
self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
|
|
||||||
|
|
||||||
self._sso_handler = hs.get_sso_handler()
|
self._sso_handler = hs.get_sso_handler()
|
||||||
|
|
||||||
def handle_redirect_request(
|
def handle_redirect_request(
|
||||||
|
@ -299,15 +295,14 @@ class SamlHandler(BaseHandler):
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
with (await self._mapping_lock.queue(self._auth_provider_id)):
|
return await self._sso_handler.get_mxid_from_sso(
|
||||||
return await self._sso_handler.get_mxid_from_sso(
|
self._auth_provider_id,
|
||||||
self._auth_provider_id,
|
remote_user_id,
|
||||||
remote_user_id,
|
user_agent,
|
||||||
user_agent,
|
ip_address,
|
||||||
ip_address,
|
saml_response_to_remapped_user_attributes,
|
||||||
saml_response_to_remapped_user_attributes,
|
grandfather_existing_users,
|
||||||
grandfather_existing_users,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
def _remote_id_from_saml_response(
|
def _remote_id_from_saml_response(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -22,6 +22,7 @@ from twisted.web.http import Request
|
||||||
from synapse.api.errors import RedirectException
|
from synapse.api.errors import RedirectException
|
||||||
from synapse.http.server import respond_with_html
|
from synapse.http.server import respond_with_html
|
||||||
from synapse.types import UserID, contains_invalid_mxid_characters
|
from synapse.types import UserID, contains_invalid_mxid_characters
|
||||||
|
from synapse.util.async_helpers import Linearizer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -54,6 +55,9 @@ class SsoHandler:
|
||||||
self._error_template = hs.config.sso_error_template
|
self._error_template = hs.config.sso_error_template
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
|
# a lock on the mappings
|
||||||
|
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
|
||||||
|
|
||||||
def render_error(
|
def render_error(
|
||||||
self, request, error: str, error_description: Optional[str] = None
|
self, request, error: str, error_description: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -172,24 +176,38 @@ class SsoHandler:
|
||||||
to an additional page. (e.g. to prompt for more information)
|
to an additional page. (e.g. to prompt for more information)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# first of all, check if we already have a mapping for this user
|
# grab a lock while we try to find a mapping for this user. This seems...
|
||||||
previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
|
# optimistic, especially for implementations that end up redirecting to
|
||||||
auth_provider_id, remote_user_id,
|
# interstitial pages.
|
||||||
)
|
with await self._mapping_lock.queue(auth_provider_id):
|
||||||
if previously_registered_user_id:
|
# first of all, check if we already have a mapping for this user
|
||||||
return previously_registered_user_id
|
previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
|
||||||
|
auth_provider_id, remote_user_id,
|
||||||
# Check for grandfathering of users.
|
)
|
||||||
if grandfather_existing_users:
|
|
||||||
previously_registered_user_id = await grandfather_existing_users()
|
|
||||||
if previously_registered_user_id:
|
if previously_registered_user_id:
|
||||||
# Future logins should also match this user ID.
|
|
||||||
await self._store.record_user_external_id(
|
|
||||||
auth_provider_id, remote_user_id, previously_registered_user_id
|
|
||||||
)
|
|
||||||
return previously_registered_user_id
|
return previously_registered_user_id
|
||||||
|
|
||||||
# Otherwise, generate a new user.
|
# Check for grandfathering of users.
|
||||||
|
if grandfather_existing_users:
|
||||||
|
previously_registered_user_id = await grandfather_existing_users()
|
||||||
|
if previously_registered_user_id:
|
||||||
|
# Future logins should also match this user ID.
|
||||||
|
await self._store.record_user_external_id(
|
||||||
|
auth_provider_id, remote_user_id, previously_registered_user_id
|
||||||
|
)
|
||||||
|
return previously_registered_user_id
|
||||||
|
|
||||||
|
# Otherwise, generate a new user.
|
||||||
|
attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
|
||||||
|
user_id = await self._register_mapped_user(
|
||||||
|
attributes, auth_provider_id, remote_user_id, user_agent, ip_address,
|
||||||
|
)
|
||||||
|
return user_id
|
||||||
|
|
||||||
|
async def _call_attribute_mapper(
|
||||||
|
self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
|
||||||
|
) -> UserAttributes:
|
||||||
|
"""Call the attribute mapper function in a loop, until we get a unique userid"""
|
||||||
for i in range(self._MAP_USERNAME_RETRIES):
|
for i in range(self._MAP_USERNAME_RETRIES):
|
||||||
try:
|
try:
|
||||||
attributes = await sso_to_matrix_id_mapper(i)
|
attributes = await sso_to_matrix_id_mapper(i)
|
||||||
|
@ -227,7 +245,16 @@ class SsoHandler:
|
||||||
raise MappingException(
|
raise MappingException(
|
||||||
"Unable to generate a Matrix ID from the SSO response"
|
"Unable to generate a Matrix ID from the SSO response"
|
||||||
)
|
)
|
||||||
|
return attributes
|
||||||
|
|
||||||
|
async def _register_mapped_user(
|
||||||
|
self,
|
||||||
|
attributes: UserAttributes,
|
||||||
|
auth_provider_id: str,
|
||||||
|
remote_user_id: str,
|
||||||
|
user_agent: str,
|
||||||
|
ip_address: str,
|
||||||
|
) -> str:
|
||||||
# Since the localpart is provided via a potentially untrusted module,
|
# Since the localpart is provided via a potentially untrusted module,
|
||||||
# ensure the MXID is valid before registering.
|
# ensure the MXID is valid before registering.
|
||||||
if contains_invalid_mxid_characters(attributes.localpart):
|
if contains_invalid_mxid_characters(attributes.localpart):
|
||||||
|
|
Loading…
Reference in a new issue