mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-22 13:13:58 +01:00
Create user with expiry
- Add unittests for client, api and handler Signed-off-by: Negar Fazeli <negar.fazeli@ericsson.com>
This commit is contained in:
parent
ae1af262f6
commit
40aa6e8349
10 changed files with 301 additions and 9 deletions
|
@ -612,7 +612,8 @@ class Auth(object):
|
|||
def get_user_from_macaroon(self, macaroon_str):
|
||||
try:
|
||||
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
||||
self.validate_macaroon(macaroon, "access", False)
|
||||
|
||||
self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token)
|
||||
|
||||
user_prefix = "user_id = "
|
||||
user = None
|
||||
|
|
|
@ -57,6 +57,8 @@ class KeyConfig(Config):
|
|||
seed = self.signing_key[0].seed
|
||||
self.macaroon_secret_key = hashlib.sha256(seed)
|
||||
|
||||
self.expire_access_token = config.get("expire_access_token", False)
|
||||
|
||||
def default_config(self, config_dir_path, server_name, is_generating_file=False,
|
||||
**kwargs):
|
||||
base_key_name = os.path.join(config_dir_path, server_name)
|
||||
|
@ -69,6 +71,9 @@ class KeyConfig(Config):
|
|||
return """\
|
||||
macaroon_secret_key: "%(macaroon_secret_key)s"
|
||||
|
||||
# Used to enable access token expiration.
|
||||
expire_access_token: False
|
||||
|
||||
## Signing Keys ##
|
||||
|
||||
# Path to the signing key to sign messages with
|
||||
|
|
|
@ -32,6 +32,7 @@ class RegistrationConfig(Config):
|
|||
)
|
||||
|
||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||
self.user_creation_max_duration = int(config["user_creation_max_duration"])
|
||||
|
||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
|
||||
|
@ -54,6 +55,11 @@ class RegistrationConfig(Config):
|
|||
# secret, even if registration is otherwise disabled.
|
||||
registration_shared_secret: "%(registration_shared_secret)s"
|
||||
|
||||
# Sets the expiry for the short term user creation in
|
||||
# milliseconds. For instance the bellow duration is two weeks
|
||||
# in milliseconds.
|
||||
user_creation_max_duration: 1209600000
|
||||
|
||||
# Set the number of bcrypt rounds used to generate password hash.
|
||||
# Larger numbers increase the work factor needed to generate the hash.
|
||||
# The default number of rounds is 12.
|
||||
|
|
|
@ -521,11 +521,11 @@ class AuthHandler(BaseHandler):
|
|||
))
|
||||
return m.serialize()
|
||||
|
||||
def generate_short_term_login_token(self, user_id):
|
||||
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = login")
|
||||
now = self.hs.get_clock().time_msec()
|
||||
expiry = now + (2 * 60 * 1000)
|
||||
expiry = now + duration_in_ms
|
||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
return macaroon.serialize()
|
||||
|
||||
|
|
|
@ -358,6 +358,59 @@ class RegistrationHandler(BaseHandler):
|
|||
)
|
||||
defer.returnValue(data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_or_create_user(self, localpart, displayname, duration_seconds):
|
||||
"""Creates a new user or returns an access token for an existing one
|
||||
|
||||
Args:
|
||||
localpart : The local part of the user ID to register. If None,
|
||||
one will be randomly generated.
|
||||
Returns:
|
||||
A tuple of (user_id, access_token).
|
||||
Raises:
|
||||
RegistrationError if there was a problem registering.
|
||||
"""
|
||||
yield run_on_reactor()
|
||||
|
||||
if localpart is None:
|
||||
raise SynapseError(400, "Request must include user id")
|
||||
|
||||
need_register = True
|
||||
|
||||
try:
|
||||
yield self.check_username(localpart)
|
||||
except SynapseError as e:
|
||||
if e.errcode == Codes.USER_IN_USE:
|
||||
need_register = False
|
||||
else:
|
||||
raise
|
||||
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
auth_handler = self.hs.get_handlers().auth_handler
|
||||
token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
|
||||
|
||||
if need_register:
|
||||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=None
|
||||
)
|
||||
|
||||
yield registered_user(self.distributor, user)
|
||||
else:
|
||||
yield self.store.flush_user(user_id=user_id)
|
||||
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
|
||||
|
||||
if displayname is not None:
|
||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||
profile_handler = self.hs.get_handlers().profile_handler
|
||||
yield profile_handler.set_displayname(
|
||||
user, user, displayname
|
||||
)
|
||||
|
||||
defer.returnValue((user_id, token))
|
||||
|
||||
def auth_handler(self):
|
||||
return self.hs.get_handlers().auth_handler
|
||||
|
||||
|
|
|
@ -355,5 +355,76 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
)
|
||||
|
||||
|
||||
class CreateUserRestServlet(ClientV1RestServlet):
|
||||
"""Handles user creation via a server-to-server interface
|
||||
"""
|
||||
|
||||
PATTERNS = client_path_patterns("/createUser$", releases=())
|
||||
|
||||
def __init__(self, hs):
|
||||
super(CreateUserRestServlet, self).__init__(hs)
|
||||
self.store = hs.get_datastore()
|
||||
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
user_json = parse_json_object_from_request(request)
|
||||
|
||||
if "access_token" not in request.args:
|
||||
raise SynapseError(400, "Expected application service token.")
|
||||
|
||||
app_service = yield self.store.get_app_service_by_token(
|
||||
request.args["access_token"][0]
|
||||
)
|
||||
if not app_service:
|
||||
raise SynapseError(403, "Invalid application service token.")
|
||||
|
||||
logger.debug("creating user: %s", user_json)
|
||||
|
||||
response = yield self._do_create(user_json)
|
||||
|
||||
defer.returnValue((200, response))
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return 403, {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_create(self, user_json):
|
||||
yield run_on_reactor()
|
||||
|
||||
if "localpart" not in user_json:
|
||||
raise SynapseError(400, "Expected 'localpart' key.")
|
||||
|
||||
if "displayname" not in user_json:
|
||||
raise SynapseError(400, "Expected 'displayname' key.")
|
||||
|
||||
if "duration_seconds" not in user_json:
|
||||
raise SynapseError(400, "Expected 'duration_seconds' key.")
|
||||
|
||||
localpart = user_json["localpart"].encode("utf-8")
|
||||
displayname = user_json["displayname"].encode("utf-8")
|
||||
duration_seconds = 0
|
||||
try:
|
||||
duration_seconds = int(user_json["duration_seconds"])
|
||||
except ValueError:
|
||||
raise SynapseError(400, "Failed to parse 'duration_seconds'")
|
||||
if duration_seconds > self.direct_user_creation_max_duration:
|
||||
duration_seconds = self.direct_user_creation_max_duration
|
||||
|
||||
handler = self.handlers.registration_handler
|
||||
user_id, token = yield handler.get_or_create_user(
|
||||
localpart=localpart,
|
||||
displayname=displayname,
|
||||
duration_seconds=duration_seconds
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname,
|
||||
})
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
RegisterRestServlet(hs).register(http_server)
|
||||
CreateUserRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -284,12 +284,12 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon.add_first_party_caveat("time < 1") # ms
|
||||
|
||||
self.hs.clock.now = 5000 # seconds
|
||||
|
||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
self.hs.config.expire_access_token = True
|
||||
# yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
# TODO(daniel): Turn on the check that we validate expiration, when we
|
||||
# validate expiration (and remove the above line, which will start
|
||||
# throwing).
|
||||
# with self.assertRaises(AuthError) as cm:
|
||||
# yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
# self.assertEqual(401, cm.exception.code)
|
||||
# self.assertIn("Invalid macaroon", cm.exception.msg)
|
||||
with self.assertRaises(AuthError) as cm:
|
||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
self.assertEqual(401, cm.exception.code)
|
||||
self.assertIn("Invalid macaroon", cm.exception.msg)
|
||||
|
|
67
tests/handlers/test_register.py
Normal file
67
tests/handlers/test_register.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from .. import unittest
|
||||
|
||||
from synapse.handlers.register import RegistrationHandler
|
||||
|
||||
from tests.utils import setup_test_homeserver
|
||||
|
||||
from mock import Mock
|
||||
|
||||
|
||||
class RegistrationHandlers(object):
|
||||
def __init__(self, hs):
|
||||
self.registration_handler = RegistrationHandler(hs)
|
||||
|
||||
|
||||
class RegistrationTestCase(unittest.TestCase):
|
||||
""" Tests the RegistrationHandler. """
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.mock_distributor = Mock()
|
||||
self.mock_distributor.declare("registered_user")
|
||||
self.mock_captcha_client = Mock()
|
||||
hs = yield setup_test_homeserver(
|
||||
handlers=None,
|
||||
http_client=None,
|
||||
expire_access_token=True)
|
||||
hs.handlers = RegistrationHandlers(hs)
|
||||
self.handler = hs.get_handlers().registration_handler
|
||||
hs.get_handlers().profile_handler = Mock()
|
||||
self.mock_handler = Mock(spec=[
|
||||
"generate_short_term_login_token",
|
||||
])
|
||||
|
||||
hs.get_handlers().auth_handler = self.mock_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||
"""
|
||||
Returns:
|
||||
The user doess not exist in this case so it will register and log it in
|
||||
"""
|
||||
duration_ms = 200
|
||||
local_part = "someone"
|
||||
display_name = "someone"
|
||||
user_id = "@someone:test"
|
||||
mock_token = self.mock_handler.generate_short_term_login_token
|
||||
mock_token.return_value = 'secret'
|
||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||
local_part, display_name, duration_ms)
|
||||
self.assertEquals(result_user_id, user_id)
|
||||
self.assertEquals(result_token, 'secret')
|
88
tests/rest/client/v1/test_register.py
Normal file
88
tests/rest/client/v1/test_register.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.rest.client.v1.register import CreateUserRestServlet
|
||||
from twisted.internet import defer
|
||||
from mock import Mock
|
||||
from tests import unittest
|
||||
import json
|
||||
|
||||
|
||||
class CreateUserServletTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# do the dance to hook up request data to self.request_data
|
||||
self.request_data = ""
|
||||
self.request = Mock(
|
||||
content=Mock(read=Mock(side_effect=lambda: self.request_data)),
|
||||
path='/_matrix/client/api/v1/createUser'
|
||||
)
|
||||
self.request.args = {}
|
||||
|
||||
self.appservice = None
|
||||
self.auth = Mock(get_appservice_by_req=Mock(
|
||||
side_effect=lambda x: defer.succeed(self.appservice))
|
||||
)
|
||||
|
||||
self.auth_result = (False, None, None, None)
|
||||
self.auth_handler = Mock(
|
||||
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
||||
get_session_data=Mock(return_value=None)
|
||||
)
|
||||
self.registration_handler = Mock()
|
||||
self.identity_handler = Mock()
|
||||
self.login_handler = Mock()
|
||||
|
||||
# do the dance to hook it up to the hs global
|
||||
self.handlers = Mock(
|
||||
auth_handler=self.auth_handler,
|
||||
registration_handler=self.registration_handler,
|
||||
identity_handler=self.identity_handler,
|
||||
login_handler=self.login_handler
|
||||
)
|
||||
self.hs = Mock()
|
||||
self.hs.hostname = "supergbig~testing~thing.com"
|
||||
self.hs.get_auth = Mock(return_value=self.auth)
|
||||
self.hs.get_handlers = Mock(return_value=self.handlers)
|
||||
self.hs.config.enable_registration = True
|
||||
# init the thing we're testing
|
||||
self.servlet = CreateUserRestServlet(self.hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_POST_createuser_with_valid_user(self):
|
||||
user_id = "@someone:interesting"
|
||||
token = "my token"
|
||||
self.request.args = {
|
||||
"access_token": "i_am_an_app_service"
|
||||
}
|
||||
self.request_data = json.dumps({
|
||||
"localpart": "someone",
|
||||
"displayname": "someone interesting",
|
||||
"duration_seconds": 200
|
||||
})
|
||||
|
||||
self.registration_handler.get_or_create_user = Mock(
|
||||
return_value=(user_id, token)
|
||||
)
|
||||
|
||||
(code, result) = yield self.servlet.on_POST(self.request)
|
||||
self.assertEquals(code, 200)
|
||||
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname
|
||||
}
|
||||
self.assertDictContainsSubset(det_data, result)
|
|
@ -49,6 +49,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
|||
config.event_cache_size = 1
|
||||
config.enable_registration = True
|
||||
config.macaroon_secret_key = "not even a little secret"
|
||||
config.expire_access_token = False
|
||||
config.server_name = "server.under.test"
|
||||
config.trusted_third_party_id_servers = []
|
||||
config.room_invite_state_types = []
|
||||
|
|
Loading…
Reference in a new issue