mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-11 12:31:58 +01:00
Allow modules to set a display name on registration (#12009)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
parent
da0e9f8efd
commit
707049c6ff
6 changed files with 195 additions and 34 deletions
1
changelog.d/12009.feature
Normal file
1
changelog.d/12009.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Enable modules to set a custom display name when registering a user.
|
|
@ -85,7 +85,7 @@ If the authentication is unsuccessful, the module must return `None`.
|
|||
If multiple modules implement this callback, they will be considered in order. If a
|
||||
callback returns `None`, Synapse falls through to the next one. The value of the first
|
||||
callback that does not return `None` will be used. If this happens, Synapse will not call
|
||||
any of the subsequent implementations of this callback. If every callback return `None`,
|
||||
any of the subsequent implementations of this callback. If every callback returns `None`,
|
||||
the authentication is denied.
|
||||
|
||||
### `on_logged_out`
|
||||
|
@ -162,10 +162,38 @@ return `None`.
|
|||
If multiple modules implement this callback, they will be considered in order. If a
|
||||
callback returns `None`, Synapse falls through to the next one. The value of the first
|
||||
callback that does not return `None` will be used. If this happens, Synapse will not call
|
||||
any of the subsequent implementations of this callback. If every callback return `None`,
|
||||
any of the subsequent implementations of this callback. If every callback returns `None`,
|
||||
the username provided by the user is used, if any (otherwise one is automatically
|
||||
generated).
|
||||
|
||||
### `get_displayname_for_registration`
|
||||
|
||||
_First introduced in Synapse v1.54.0_
|
||||
|
||||
```python
|
||||
async def get_displayname_for_registration(
|
||||
uia_results: Dict[str, Any],
|
||||
params: Dict[str, Any],
|
||||
) -> Optional[str]
|
||||
```
|
||||
|
||||
Called when registering a new user. The module can return a display name to set for the
|
||||
user being registered by returning it as a string, or `None` if it doesn't wish to force a
|
||||
display name for this user.
|
||||
|
||||
This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api)
|
||||
has been completed by the user. It is not called when registering a user via SSO. It is
|
||||
passed two dictionaries, which include the information that the user has provided during
|
||||
the registration process. These dictionaries are identical to the ones passed to
|
||||
[`get_username_for_registration`](#get_username_for_registration), so refer to the
|
||||
documentation of this callback for more information about them.
|
||||
|
||||
If multiple modules implement this callback, they will be considered in order. If a
|
||||
callback returns `None`, Synapse falls through to the next one. The value of the first
|
||||
callback that does not return `None` will be used. If this happens, Synapse will not call
|
||||
any of the subsequent implementations of this callback. If every callback returns `None`,
|
||||
the username will be used (e.g. `alice` if the user being registered is `@alice:example.com`).
|
||||
|
||||
## `is_3pid_allowed`
|
||||
|
||||
_First introduced in Synapse v1.53.0_
|
||||
|
@ -196,7 +224,6 @@ The example module below implements authentication checkers for two different lo
|
|||
- Expects a `password` field to be sent to `/login`
|
||||
- Is checked by the method: `self.check_pass`
|
||||
|
||||
|
||||
```python
|
||||
from typing import Awaitable, Callable, Optional, Tuple
|
||||
|
||||
|
|
|
@ -2064,6 +2064,10 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
|
|||
[JsonDict, JsonDict],
|
||||
Awaitable[Optional[str]],
|
||||
]
|
||||
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
|
||||
[JsonDict, JsonDict],
|
||||
Awaitable[Optional[str]],
|
||||
]
|
||||
IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
|
||||
|
||||
|
||||
|
@ -2080,6 +2084,9 @@ class PasswordAuthProvider:
|
|||
self.get_username_for_registration_callbacks: List[
|
||||
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
||||
] = []
|
||||
self.get_displayname_for_registration_callbacks: List[
|
||||
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
|
||||
] = []
|
||||
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
|
||||
|
||||
# Mapping from login type to login parameters
|
||||
|
@ -2099,6 +2106,9 @@ class PasswordAuthProvider:
|
|||
get_username_for_registration: Optional[
|
||||
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
||||
] = None,
|
||||
get_displayname_for_registration: Optional[
|
||||
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
|
||||
] = None,
|
||||
) -> None:
|
||||
# Register check_3pid_auth callback
|
||||
if check_3pid_auth is not None:
|
||||
|
@ -2148,6 +2158,11 @@ class PasswordAuthProvider:
|
|||
get_username_for_registration,
|
||||
)
|
||||
|
||||
if get_displayname_for_registration is not None:
|
||||
self.get_displayname_for_registration_callbacks.append(
|
||||
get_displayname_for_registration,
|
||||
)
|
||||
|
||||
if is_3pid_allowed is not None:
|
||||
self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
|
||||
|
||||
|
@ -2350,6 +2365,49 @@ class PasswordAuthProvider:
|
|||
|
||||
return None
|
||||
|
||||
async def get_displayname_for_registration(
|
||||
self,
|
||||
uia_results: JsonDict,
|
||||
params: JsonDict,
|
||||
) -> Optional[str]:
|
||||
"""Defines the display name to use when registering the user, using the
|
||||
credentials and parameters provided during the UIA flow.
|
||||
|
||||
Stops at the first callback that returns a tuple containing at least one string.
|
||||
|
||||
Args:
|
||||
uia_results: The credentials provided during the UIA flow.
|
||||
params: The parameters provided by the registration request.
|
||||
|
||||
Returns:
|
||||
A tuple which first element is the display name, and the second is an MXC URL
|
||||
to the user's avatar.
|
||||
"""
|
||||
for callback in self.get_displayname_for_registration_callbacks:
|
||||
try:
|
||||
res = await callback(uia_results, params)
|
||||
|
||||
if isinstance(res, str):
|
||||
return res
|
||||
elif res is not None:
|
||||
# mypy complains that this line is unreachable because it assumes the
|
||||
# data returned by the module fits the expected type. We just want
|
||||
# to make sure this is the case.
|
||||
logger.warning( # type: ignore[unreachable]
|
||||
"Ignoring non-string value returned by"
|
||||
" get_displayname_for_registration callback %s: %s",
|
||||
callback,
|
||||
res,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Module raised an exception in get_displayname_for_registration: %s",
|
||||
e,
|
||||
)
|
||||
raise SynapseError(code=500, msg="Internal Server Error")
|
||||
|
||||
return None
|
||||
|
||||
async def is_3pid_allowed(
|
||||
self,
|
||||
medium: str,
|
||||
|
|
|
@ -70,6 +70,7 @@ from synapse.handlers.account_validity import (
|
|||
from synapse.handlers.auth import (
|
||||
CHECK_3PID_AUTH_CALLBACK,
|
||||
CHECK_AUTH_CALLBACK,
|
||||
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK,
|
||||
GET_USERNAME_FOR_REGISTRATION_CALLBACK,
|
||||
IS_3PID_ALLOWED_CALLBACK,
|
||||
ON_LOGGED_OUT_CALLBACK,
|
||||
|
@ -317,6 +318,9 @@ class ModuleApi:
|
|||
get_username_for_registration: Optional[
|
||||
GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
||||
] = None,
|
||||
get_displayname_for_registration: Optional[
|
||||
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
|
||||
] = None,
|
||||
) -> None:
|
||||
"""Registers callbacks for password auth provider capabilities.
|
||||
|
||||
|
@ -328,6 +332,7 @@ class ModuleApi:
|
|||
is_3pid_allowed=is_3pid_allowed,
|
||||
auth_checkers=auth_checkers,
|
||||
get_username_for_registration=get_username_for_registration,
|
||||
get_displayname_for_registration=get_displayname_for_registration,
|
||||
)
|
||||
|
||||
def register_background_update_controller_callbacks(
|
||||
|
|
|
@ -694,11 +694,18 @@ class RegisterRestServlet(RestServlet):
|
|||
session_id
|
||||
)
|
||||
|
||||
display_name = await (
|
||||
self.password_auth_provider.get_displayname_for_registration(
|
||||
auth_result, params
|
||||
)
|
||||
)
|
||||
|
||||
registered_user_id = await self.registration_handler.register_user(
|
||||
localpart=desired_username,
|
||||
password_hash=password_hash,
|
||||
guest_access_token=guest_access_token,
|
||||
threepid=threepid,
|
||||
default_display_name=display_name,
|
||||
address=client_addr,
|
||||
user_agent_ips=entries,
|
||||
)
|
||||
|
|
|
@ -84,7 +84,7 @@ class CustomAuthProvider:
|
|||
|
||||
def __init__(self, config, api: ModuleApi):
|
||||
api.register_password_auth_provider_callbacks(
|
||||
auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
|
||||
auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
|
||||
)
|
||||
|
||||
def check_auth(self, *args):
|
||||
|
@ -122,7 +122,7 @@ class PasswordCustomAuthProvider:
|
|||
auth_checkers={
|
||||
("test.login_type", ("test_field",)): self.check_auth,
|
||||
("m.login.password", ("password",)): self.check_auth,
|
||||
},
|
||||
}
|
||||
)
|
||||
pass
|
||||
|
||||
|
@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
account.register_servlets,
|
||||
]
|
||||
|
||||
CALLBACK_USERNAME = "get_username_for_registration"
|
||||
CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
|
||||
|
||||
def setUp(self):
|
||||
# we use a global mock device, so make sure we are starting with a clean slate
|
||||
mock_password_provider.reset_mock()
|
||||
|
@ -754,7 +757,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
"""Tests that the get_username_for_registration callback can define the username
|
||||
of a user when registering.
|
||||
"""
|
||||
self._setup_get_username_for_registration()
|
||||
self._setup_get_name_for_registration(
|
||||
callback_name=self.CALLBACK_USERNAME,
|
||||
)
|
||||
|
||||
username = "rin"
|
||||
channel = self.make_request(
|
||||
|
@ -777,30 +782,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
"""Tests that the get_username_for_registration callback is only called at the
|
||||
end of the UIA flow.
|
||||
"""
|
||||
m = self._setup_get_username_for_registration()
|
||||
m = self._setup_get_name_for_registration(
|
||||
callback_name=self.CALLBACK_USERNAME,
|
||||
)
|
||||
|
||||
# Initiate the UIA flow.
|
||||
username = "rin"
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"register",
|
||||
{"username": username, "type": "m.login.password", "password": "bar"},
|
||||
)
|
||||
self.assertEqual(channel.code, 401)
|
||||
self.assertIn("session", channel.json_body)
|
||||
res = self._do_uia_assert_mock_not_called(username, m)
|
||||
|
||||
# Check that the callback hasn't been called yet.
|
||||
m.assert_not_called()
|
||||
|
||||
# Finish the UIA flow.
|
||||
session = channel.json_body["session"]
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"register",
|
||||
{"auth": {"session": session, "type": LoginType.DUMMY}},
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
mxid = channel.json_body["user_id"]
|
||||
mxid = res["user_id"]
|
||||
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
|
||||
|
||||
# Check that the callback has been called.
|
||||
|
@ -817,6 +806,56 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
self._test_3pid_allowed("rin", False)
|
||||
self._test_3pid_allowed("kitay", True)
|
||||
|
||||
def test_displayname(self):
|
||||
"""Tests that the get_displayname_for_registration callback can define the
|
||||
display name of a user when registering.
|
||||
"""
|
||||
self._setup_get_name_for_registration(
|
||||
callback_name=self.CALLBACK_DISPLAYNAME,
|
||||
)
|
||||
|
||||
username = "rin"
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/register",
|
||||
{
|
||||
"username": username,
|
||||
"password": "bar",
|
||||
"auth": {"type": LoginType.DUMMY},
|
||||
},
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# Our callback takes the username and appends "-foo" to it, check that's what we
|
||||
# have.
|
||||
user_id = UserID.from_string(channel.json_body["user_id"])
|
||||
display_name = self.get_success(
|
||||
self.hs.get_profile_handler().get_displayname(user_id)
|
||||
)
|
||||
|
||||
self.assertEqual(display_name, username + "-foo")
|
||||
|
||||
def test_displayname_uia(self):
|
||||
"""Tests that the get_displayname_for_registration callback is only called at the
|
||||
end of the UIA flow.
|
||||
"""
|
||||
m = self._setup_get_name_for_registration(
|
||||
callback_name=self.CALLBACK_DISPLAYNAME,
|
||||
)
|
||||
|
||||
username = "rin"
|
||||
res = self._do_uia_assert_mock_not_called(username, m)
|
||||
|
||||
user_id = UserID.from_string(res["user_id"])
|
||||
display_name = self.get_success(
|
||||
self.hs.get_profile_handler().get_displayname(user_id)
|
||||
)
|
||||
|
||||
self.assertEqual(display_name, username + "-foo")
|
||||
|
||||
# Check that the callback has been called.
|
||||
m.assert_called_once()
|
||||
|
||||
def _test_3pid_allowed(self, username: str, registration: bool):
|
||||
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
|
||||
either /register or /account URLs depending on the arguments.
|
||||
|
@ -877,23 +916,47 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
|||
|
||||
m.assert_called_once_with("email", "bar@test.com", registration)
|
||||
|
||||
def _setup_get_username_for_registration(self) -> Mock:
|
||||
"""Registers a get_username_for_registration callback that appends "-foo" to the
|
||||
username the client is trying to register.
|
||||
def _setup_get_name_for_registration(self, callback_name: str) -> Mock:
|
||||
"""Registers either a get_username_for_registration callback or a
|
||||
get_displayname_for_registration callback that appends "-foo" to the username the
|
||||
client is trying to register.
|
||||
"""
|
||||
|
||||
async def get_username_for_registration(uia_results, params):
|
||||
async def callback(uia_results, params):
|
||||
self.assertIn(LoginType.DUMMY, uia_results)
|
||||
username = params["username"]
|
||||
return username + "-foo"
|
||||
|
||||
m = Mock(side_effect=get_username_for_registration)
|
||||
m = Mock(side_effect=callback)
|
||||
|
||||
password_auth_provider = self.hs.get_password_auth_provider()
|
||||
password_auth_provider.get_username_for_registration_callbacks.append(m)
|
||||
getattr(password_auth_provider, callback_name + "_callbacks").append(m)
|
||||
|
||||
return m
|
||||
|
||||
def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict:
|
||||
# Initiate the UIA flow.
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"register",
|
||||
{"username": username, "type": "m.login.password", "password": "bar"},
|
||||
)
|
||||
self.assertEqual(channel.code, 401)
|
||||
self.assertIn("session", channel.json_body)
|
||||
|
||||
# Check that the callback hasn't been called yet.
|
||||
m.assert_not_called()
|
||||
|
||||
# Finish the UIA flow.
|
||||
session = channel.json_body["session"]
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"register",
|
||||
{"auth": {"session": session, "type": LoginType.DUMMY}},
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
return channel.json_body
|
||||
|
||||
def _get_login_flows(self) -> JsonDict:
|
||||
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
|
Loading…
Reference in a new issue