From dcd5983fe4eb2a7c403c9969d2d8c12342440ff7 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 11 Aug 2015 16:54:06 +0100 Subject: [PATCH 1/4] Remove call to recently removed function in mock --- tests/test_distributor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_distributor.py b/tests/test_distributor.py index 6a0095d85..8ed48cfb7 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -73,8 +73,8 @@ class DistributorTestCase(unittest.TestCase): yield d self.assertTrue(d.called) - observers[0].assert_called_once("Go") - observers[1].assert_called_once("Go") + observers[0].assert_called_once_with("Go") + observers[1].assert_called_once_with("Go") self.assertEquals(mock_logger.warning.call_count, 1) self.assertIsInstance(mock_logger.warning.call_args[0][0], From 415c2f05491ce65a4fc34326519754cd1edd9c54 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Wed, 12 Aug 2015 15:49:37 +0100 Subject: [PATCH 2/4] Simplify LoginHander and AuthHandler * Merge LoginHandler -> AuthHandler * Add a bunch of documentation * Improve some naming * Remove unused branches I will start merging the actual logic of the two handlers shortly --- synapse/handlers/__init__.py | 2 - synapse/handlers/auth.py | 90 +++++++++++++++++++----- synapse/handlers/login.py | 83 ---------------------- synapse/handlers/register.py | 10 +-- synapse/push/pusherpool.py | 11 ++- synapse/rest/client/v1/login.py | 5 +- synapse/rest/client/v2_alpha/account.py | 8 +-- synapse/rest/client/v2_alpha/register.py | 3 +- synapse/storage/registration.py | 12 ++-- 9 files changed, 93 insertions(+), 131 deletions(-) delete mode 100644 synapse/handlers/login.py diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index dc5b6ef79..8725c3c42 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -22,7 +22,6 @@ from .room import ( from .message import MessageHandler from .events import EventStreamHandler, EventHandler from .federation import FederationHandler -from .login import LoginHandler from .profile import ProfileHandler from .presence import PresenceHandler from .directory import DirectoryHandler @@ -54,7 +53,6 @@ class Handlers(object): self.profile_handler = ProfileHandler(hs) self.presence_handler = PresenceHandler(hs) self.room_list_handler = RoomListHandler(hs) - self.login_handler = LoginHandler(hs) self.directory_handler = DirectoryHandler(hs) self.typing_notification_handler = TypingNotificationHandler(hs) self.admin_handler = AdminHandler(hs) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 1ecf7fef1..1504b00d7 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -47,17 +47,24 @@ class AuthHandler(BaseHandler): self.sessions = {} @defer.inlineCallbacks - def check_auth(self, flows, clientdict, clientip=None): + def check_auth(self, flows, clientdict, clientip): """ Takes a dictionary sent by the client in the login / registration protocol and handles the login flow. + As a side effect, this function fills in the 'creds' key on the user's + session with a map, which maps each auth-type (str) to the relevant + identity authenticated by that auth-type (mostly str, but for captcha, bool). + Args: - flows: list of list of stages - authdict: The dictionary from the client root level, not the - 'auth' key: this method prompts for auth if none is sent. + flows (list): A list of login flows. Each flow is an ordered list of + strings representing auth-types. At least one full + flow must be completed in order for auth to be successful. + clientdict: The dictionary from the client root level, not the + 'auth' key: this method prompts for auth if none is sent. + clientip (str): The IP address of the client. Returns: - A tuple of authed, dict, dict where authed is true if the client + A tuple of (authed, dict, dict) where authed is true if the client has successfully completed an auth flow. If it is true, the first dict contains the authenticated credentials of each stage. @@ -75,7 +82,7 @@ class AuthHandler(BaseHandler): del clientdict['auth'] if 'session' in authdict: sid = authdict['session'] - sess = self._get_session_info(sid) + session = self._get_session_info(sid) if len(clientdict) > 0: # This was designed to allow the client to omit the parameters @@ -87,20 +94,19 @@ class AuthHandler(BaseHandler): # on a home server. # Revisit: Assumimg the REST APIs do sensible validation, the data # isn't arbintrary. - sess['clientdict'] = clientdict - self._save_session(sess) - pass - elif 'clientdict' in sess: - clientdict = sess['clientdict'] + session['clientdict'] = clientdict + self._save_session(session) + elif 'clientdict' in session: + clientdict = session['clientdict'] if not authdict: defer.returnValue( - (False, self._auth_dict_for_flows(flows, sess), clientdict) + (False, self._auth_dict_for_flows(flows, session), clientdict) ) - if 'creds' not in sess: - sess['creds'] = {} - creds = sess['creds'] + if 'creds' not in session: + session['creds'] = {} + creds = session['creds'] # check auth type currently being presented if 'type' in authdict: @@ -109,15 +115,15 @@ class AuthHandler(BaseHandler): result = yield self.checkers[authdict['type']](authdict, clientip) if result: creds[authdict['type']] = result - self._save_session(sess) + self._save_session(session) for f in flows: if len(set(f) - set(creds.keys())) == 0: logger.info("Auth completed with creds: %r", creds) - self._remove_session(sess) + self._remove_session(session) defer.returnValue((True, creds, clientdict)) - ret = self._auth_dict_for_flows(flows, sess) + ret = self._auth_dict_for_flows(flows, session) ret['completed'] = creds.keys() defer.returnValue((False, ret, clientdict)) @@ -270,6 +276,54 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] + @defer.inlineCallbacks + def login_with_password(self, user_id, password): + """ + Authenticates the user with their username and password. + + Used only by the v1 login API. + + Args: + user_id (str): User ID + password (str): Password + Returns: + The access token for the user's session. + Raises: + StoreError if there was a problem storing the token. + LoginError if there was an authentication problem. + """ + user_info = yield self.store.get_user_by_id(user_id=user_id) + if not user_info: + logger.warn("Attempted to login as %s but they do not exist", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + + stored_hash = user_info["password_hash"] + if not bcrypt.checkpw(password, stored_hash): + logger.warn("Failed password login for user %s", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + + reg_handler = self.hs.get_handlers().registration_handler + access_token = reg_handler.generate_token(user_id) + logger.info("Adding token %s for user %s", access_token, user_id) + yield self.store.add_access_token_to_user(user_id, access_token) + defer.returnValue(access_token) + + @defer.inlineCallbacks + def set_password(self, user_id, newpassword): + password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) + + yield self.store.user_set_password_hash(user_id, password_hash) + yield self.store.user_delete_access_tokens(user_id) + yield self.hs.get_pusherpool().remove_pushers_by_user(user_id) + yield self.store.flush_user(user_id) + + @defer.inlineCallbacks + def add_threepid(self, user_id, medium, address, validated_at): + yield self.store.user_add_threepid( + user_id, medium, address, validated_at, + self.hs.get_clock().time_msec() + ) + def _save_session(self, session): # TODO: Persistent storage logger.debug("Saving session %s", session) diff --git a/synapse/handlers/login.py b/synapse/handlers/login.py deleted file mode 100644 index 91d87d503..000000000 --- a/synapse/handlers/login.py +++ /dev/null @@ -1,83 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import defer - -from ._base import BaseHandler -from synapse.api.errors import LoginError, Codes - -import bcrypt -import logging - -logger = logging.getLogger(__name__) - - -class LoginHandler(BaseHandler): - - def __init__(self, hs): - super(LoginHandler, self).__init__(hs) - self.hs = hs - - @defer.inlineCallbacks - def login(self, user, password): - """Login as the specified user with the specified password. - - Args: - user (str): The user ID. - password (str): The password. - Returns: - The newly allocated access token. - Raises: - StoreError if there was a problem storing the token. - LoginError if there was an authentication problem. - """ - # TODO do this better, it can't go in __init__ else it cyclic loops - if not hasattr(self, "reg_handler"): - self.reg_handler = self.hs.get_handlers().registration_handler - - # pull out the hash for this user if they exist - user_info = yield self.store.get_user_by_id(user_id=user) - if not user_info: - logger.warn("Attempted to login as %s but they do not exist", user) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - stored_hash = user_info["password_hash"] - if bcrypt.checkpw(password, stored_hash): - # generate an access token and store it. - token = self.reg_handler._generate_token(user) - logger.info("Adding token %s for user %s", token, user) - yield self.store.add_access_token_to_user(user, token) - defer.returnValue(token) - else: - logger.warn("Failed password login for user %s", user) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - @defer.inlineCallbacks - def set_password(self, user_id, newpassword, token_id=None): - password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) - - yield self.store.user_set_password_hash(user_id, password_hash) - yield self.store.user_delete_access_tokens_apart_from(user_id, token_id) - yield self.hs.get_pusherpool().remove_pushers_by_user_access_token( - user_id, token_id - ) - yield self.store.flush_user(user_id) - - @defer.inlineCallbacks - def add_threepid(self, user_id, medium, address, validated_at): - yield self.store.user_add_threepid( - user_id, medium, address, validated_at, - self.hs.get_clock().time_msec() - ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f81d75017..39392d9fd 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -91,7 +91,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self._generate_token(user_id) + token = self.generate_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -111,7 +111,7 @@ class RegistrationHandler(BaseHandler): user_id = user.to_string() yield self.check_user_id_is_valid(user_id) - token = self._generate_token(user_id) + token = self.generate_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -161,7 +161,7 @@ class RegistrationHandler(BaseHandler): 400, "Invalid user localpart for this application service.", errcode=Codes.EXCLUSIVE ) - token = self._generate_token(user_id) + token = self.generate_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -208,7 +208,7 @@ class RegistrationHandler(BaseHandler): user_id = user.to_string() yield self.check_user_id_is_valid(user_id) - token = self._generate_token(user_id) + token = self.generate_token(user_id) try: yield self.store.register( user_id=user_id, @@ -273,7 +273,7 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE ) - def _generate_token(self, user_id): + def generate_token(self, user_id): # urlsafe variant uses _ and - so use . as the separator and replace # all =s with .s so http clients don't quote =s when it is used as # query params. diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 0ab2f6597..e012c565e 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -94,17 +94,14 @@ class PusherPool: self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) @defer.inlineCallbacks - def remove_pushers_by_user_access_token(self, user_id, not_access_token_id): + def remove_pushers_by_user(self, user_id): all = yield self.store.get_all_pushers() logger.info( - "Removing all pushers for user %s except access token %s", - user_id, not_access_token_id + "Removing all pushers for user %s", + user_id, ) for p in all: - if ( - p['user_name'] == user_id and - p['access_token'] != not_access_token_id - ): + if p['user_name'] == user_id: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", p['app_id'], p['pushkey'], p['user_name'] diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 998d4d44c..694072693 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -78,9 +78,8 @@ class LoginRestServlet(ClientV1RestServlet): login_submission["user"] = UserID.create( login_submission["user"], self.hs.hostname).to_string() - handler = self.handlers.login_handler - token = yield handler.login( - user=login_submission["user"], + token = yield self.handlers.auth_handler.login_with_password( + user_id=login_submission["user"], password=login_submission["password"]) result = { diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index b082140f1..897c54b53 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -36,7 +36,6 @@ class PasswordRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_handlers().auth_handler - self.login_handler = hs.get_handlers().login_handler @defer.inlineCallbacks def on_POST(self, request): @@ -47,7 +46,7 @@ class PasswordRestServlet(RestServlet): authed, result, params = yield self.auth_handler.check_auth([ [LoginType.PASSWORD], [LoginType.EMAIL_IDENTITY] - ], body) + ], body, self.hs.get_ip_from_request(request)) if not authed: defer.returnValue((401, result)) @@ -79,7 +78,7 @@ class PasswordRestServlet(RestServlet): raise SynapseError(400, "", Codes.MISSING_PARAM) new_password = params['new_password'] - yield self.login_handler.set_password( + yield self.auth_handler.set_password( user_id, new_password, None ) @@ -95,7 +94,6 @@ class ThreepidRestServlet(RestServlet): def __init__(self, hs): super(ThreepidRestServlet, self).__init__() self.hs = hs - self.login_handler = hs.get_handlers().login_handler self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -135,7 +133,7 @@ class ThreepidRestServlet(RestServlet): logger.warn("Couldn't add 3pid: invalid response from ID sevrer") raise SynapseError(500, "Invalid response from ID Server") - yield self.login_handler.add_threepid( + yield self.auth_handler.add_threepid( auth_user.to_string(), threepid['medium'], threepid['address'], diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b5926f9ca..012c447e8 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -50,7 +50,6 @@ class RegisterRestServlet(RestServlet): self.auth_handler = hs.get_handlers().auth_handler self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler - self.login_handler = hs.get_handlers().login_handler @defer.inlineCallbacks def on_POST(self, request): @@ -143,7 +142,7 @@ class RegisterRestServlet(RestServlet): if reqd not in threepid: logger.info("Can't add incomplete 3pid") else: - yield self.login_handler.add_threepid( + yield self.auth_handler.add_threepid( user_id, threepid['medium'], threepid['address'], diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 4eaa088b3..d2d5b07cb 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -111,16 +111,16 @@ class RegistrationStore(SQLBaseStore): }) @defer.inlineCallbacks - def user_delete_access_tokens_apart_from(self, user_id, token_id): + def user_delete_access_tokens(self, user_id): yield self.runInteraction( - "user_delete_access_tokens_apart_from", - self._user_delete_access_tokens_apart_from, user_id, token_id + "user_delete_access_tokens", + self._user_delete_access_tokens, user_id ) - def _user_delete_access_tokens_apart_from(self, txn, user_id, token_id): + def _user_delete_access_tokens(self, txn, user_id): txn.execute( - "DELETE FROM access_tokens WHERE user_id = ? AND id != ?", - (user_id, token_id) + "DELETE FROM access_tokens WHERE user_id = ?", + (user_id, ) ) @defer.inlineCallbacks From 5ce903e2f7a4136b67ec16564e24244b00118367 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Wed, 12 Aug 2015 16:09:19 +0100 Subject: [PATCH 3/4] Merge password checking implementations --- synapse/handlers/auth.py | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 1504b00d7..98d99dd0a 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -157,22 +157,13 @@ class AuthHandler(BaseHandler): if "user" not in authdict or "password" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) - user = authdict["user"] + user_id = authdict["user"] password = authdict["password"] - if not user.startswith('@'): - user = UserID.create(user, self.hs.hostname).to_string() + if not user_id.startswith('@'): + user_id = UserID.create(user_id, self.hs.hostname).to_string() - user_info = yield self.store.get_user_by_id(user_id=user) - if not user_info: - logger.warn("Attempted to login as %s but they do not exist", user) - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) - - stored_hash = user_info["password_hash"] - if bcrypt.checkpw(password, stored_hash): - defer.returnValue(user) - else: - logger.warn("Failed password login for user %s", user) - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + self._check_password(user_id, password) + defer.returnValue(user_id) @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip): @@ -292,6 +283,16 @@ class AuthHandler(BaseHandler): StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ + self._check_password(user_id, password) + + reg_handler = self.hs.get_handlers().registration_handler + access_token = reg_handler.generate_token(user_id) + logger.info("Adding token %s for user %s", access_token, user_id) + yield self.store.add_access_token_to_user(user_id, access_token) + defer.returnValue(access_token) + + def _check_password(self, user_id, password): + """Checks that user_id has passed password, raises LoginError if not.""" user_info = yield self.store.get_user_by_id(user_id=user_id) if not user_info: logger.warn("Attempted to login as %s but they do not exist", user_id) @@ -302,12 +303,6 @@ class AuthHandler(BaseHandler): logger.warn("Failed password login for user %s", user_id) raise LoginError(403, "", errcode=Codes.FORBIDDEN) - reg_handler = self.hs.get_handlers().registration_handler - access_token = reg_handler.generate_token(user_id) - logger.info("Adding token %s for user %s", access_token, user_id) - yield self.store.add_access_token_to_user(user_id, access_token) - defer.returnValue(access_token) - @defer.inlineCallbacks def set_password(self, user_id, newpassword): password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) From f9d4da7f4502aeefe7a5a6a9ec0d2682458e7834 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 17 Aug 2015 09:39:12 +0100 Subject: [PATCH 4/4] Fix bug where we were leaking None into state event lists --- synapse/storage/state.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 185f88fd7..ecb62e6df 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -412,9 +412,10 @@ class StateStore(SQLBaseStore): full=(types is None), ) - results[group].update({ + # We replace here to remove all the entries with None values. + results[group] = { key: value for key, value in state_dict.items() if value - }) + } defer.returnValue(results)