0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-19 03:11:56 +01:00

Merge pull request #462 from matrix-org/daniel/guestupgrade

Allow guests to upgrade their accounts
This commit is contained in:
Daniel Wagner-Hall 2016-01-05 18:01:29 +00:00
commit 29e131df43
11 changed files with 93 additions and 40 deletions

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -583,7 +583,7 @@ class Auth(object):
AuthError if no user by that token exists or the token is invalid.
"""
try:
ret = yield self._get_user_from_macaroon(token)
ret = yield self.get_user_from_macaroon(token)
except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons.
@ -591,7 +591,7 @@ class Auth(object):
defer.returnValue(ret)
@defer.inlineCallbacks
def _get_user_from_macaroon(self, macaroon_str):
def get_user_from_macaroon(self, macaroon_str):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self.validate_macaroon(macaroon, "access", False)

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -408,7 +408,7 @@ class AuthHandler(BaseHandler):
macaroon = pymacaroons.Macaroon.deserialize(login_token)
auth_api = self.hs.get_auth()
auth_api.validate_macaroon(macaroon, "login", True)
return self._get_user_from_macaroon(macaroon)
return self.get_user_from_macaroon(macaroon)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
@ -421,7 +421,7 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
def _get_user_from_macaroon(self, macaroon):
def get_user_from_macaroon(self, macaroon):
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -40,12 +40,13 @@ class RegistrationHandler(BaseHandler):
def __init__(self, hs):
super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth()
self.distributor = hs.get_distributor()
self.distributor.declare("registered_user")
self.captcha_client = CaptchaServerHttpClient(hs)
@defer.inlineCallbacks
def check_username(self, localpart):
def check_username(self, localpart, guest_access_token=None):
yield run_on_reactor()
if urllib.quote(localpart) != localpart:
@ -62,14 +63,29 @@ class RegistrationHandler(BaseHandler):
users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users:
raise SynapseError(
400,
"User ID already taken.",
errcode=Codes.USER_IN_USE,
)
if not guest_access_token:
raise SynapseError(
400,
"User ID already taken.",
errcode=Codes.USER_IN_USE,
)
user_data = yield self.auth.get_user_from_macaroon(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
raise AuthError(
403,
"Cannot register taken user ID without valid guest "
"credentials for that user.",
errcode=Codes.FORBIDDEN,
)
@defer.inlineCallbacks
def register(self, localpart=None, password=None, generate_token=True):
def register(
self,
localpart=None,
password=None,
generate_token=True,
guest_access_token=None
):
"""Registers a new client on the server.
Args:
@ -89,7 +105,7 @@ class RegistrationHandler(BaseHandler):
password_hash = self.auth_handler().hash(password)
if localpart:
yield self.check_username(localpart)
yield self.check_username(localpart, guest_access_token=guest_access_token)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@ -100,7 +116,8 @@ class RegistrationHandler(BaseHandler):
yield self.store.register(
user_id=user_id,
token=token,
password_hash=password_hash
password_hash=password_hash,
was_guest=guest_access_token is not None,
)
yield registered_user(self.distributor, user)

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
# Copyright 2015 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
# Copyright 2015 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -119,8 +119,13 @@ class RegisterRestServlet(RestServlet):
if self.hs.config.disable_registration:
raise SynapseError(403, "Registration has been disabled")
guest_access_token = body.get("guest_access_token", None)
if desired_username is not None:
yield self.registration_handler.check_username(desired_username)
yield self.registration_handler.check_username(
desired_username,
guest_access_token=guest_access_token
)
if self.hs.config.enable_registration_captcha:
flows = [
@ -150,7 +155,8 @@ class RegisterRestServlet(RestServlet):
(user_id, token) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password
password=new_password,
guest_access_token=guest_access_token,
)
if result and LoginType.EMAIL_IDENTITY in result:

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 27
SCHEMA_VERSION = 28
dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -73,30 +73,39 @@ class RegistrationStore(SQLBaseStore):
)
@defer.inlineCallbacks
def register(self, user_id, token, password_hash):
def register(self, user_id, token, password_hash, was_guest=False):
"""Attempts to register an account.
Args:
user_id (str): The desired user ID to register.
token (str): The desired access token to use for this user.
password_hash (str): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
Raises:
StoreError if the user_id could not be registered.
"""
yield self.runInteraction(
"register",
self._register, user_id, token, password_hash
self._register, user_id, token, password_hash, was_guest
)
def _register(self, txn, user_id, token, password_hash):
def _register(self, txn, user_id, token, password_hash, was_guest):
now = int(self.clock.time())
next_id = self._access_tokens_id_gen.get_next_txn(txn)
try:
txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
"VALUES (?,?,?)",
[user_id, password_hash, now])
if was_guest:
txn.execute("UPDATE users SET"
" password_hash = ?,"
" upgrade_ts = ?"
" WHERE name = ?",
[password_hash, now, user_id])
else:
txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
"VALUES (?,?,?)",
[user_id, password_hash, now])
except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE

View file

@ -0,0 +1,21 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Stores the timestamp when a user upgraded from a guest to a full user, if
* that happened.
*/
ALTER TABLE users ADD COLUMN upgrade_ts BIGINT;

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
# Copyright 2015 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -154,7 +154,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize())
user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize())
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)
@ -171,7 +171,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize()
user_info = yield self.auth._get_user_from_macaroon(serialized)
user_info = yield self.auth.get_user_from_macaroon(serialized)
user = user_info["user"]
is_guest = user_info["is_guest"]
self.assertEqual(UserID.from_string(user_id), user)
@ -192,7 +192,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
yield self.auth.get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("User mismatch", cm.exception.msg)
@ -212,7 +212,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("type = access")
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
yield self.auth.get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("No user caveat", cm.exception.msg)
@ -234,7 +234,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("user_id = %s" % (user,))
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
yield self.auth.get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("Invalid macaroon", cm.exception.msg)
@ -257,7 +257,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("cunning > fox")
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
yield self.auth.get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("Invalid macaroon", cm.exception.msg)
@ -285,11 +285,11 @@ class AuthTestCase(unittest.TestCase):
self.hs.clock.now = 5000 # seconds
yield self.auth._get_user_from_macaroon(macaroon.serialize())
yield self.auth.get_user_from_macaroon(macaroon.serialize())
# TODO(daniel): Turn on the check that we validate expiration, when we
# validate expiration (and remove the above line, which will start
# throwing).
# with self.assertRaises(AuthError) as cm:
# yield self.auth._get_user_from_macaroon(macaroon.serialize())
# yield self.auth.get_user_from_macaroon(macaroon.serialize())
# self.assertEqual(401, cm.exception.code)
# self.assertIn("Invalid macaroon", cm.exception.msg)