Merge branch 'erikj/user_dedup' into release-v0.10.0

This commit is contained in:
Erik Johnston 2015-08-21 12:51:14 +01:00
commit 5dbd102470
4 changed files with 48 additions and 12 deletions

View file

@ -162,7 +162,8 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'): if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string() user_id = UserID.create(user_id, self.hs.hostname).to_string()
yield self._check_password(user_id, password) user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
defer.returnValue(user_id) defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -283,23 +284,43 @@ class AuthHandler(BaseHandler):
StoreError if there was a problem storing the token. StoreError if there was a problem storing the token.
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
yield self._check_password(user_id, password) user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
reg_handler = self.hs.get_handlers().registration_handler reg_handler = self.hs.get_handlers().registration_handler
access_token = reg_handler.generate_token(user_id) access_token = reg_handler.generate_token(user_id)
logger.info("Logging in user %s", user_id) logger.info("Logging in user %s", user_id)
yield self.store.add_access_token_to_user(user_id, access_token) yield self.store.add_access_token_to_user(user_id, access_token)
defer.returnValue(access_token) defer.returnValue((user_id, access_token))
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password(self, user_id, password): def _find_user_id_and_pwd_hash(self, user_id):
"""Checks that user_id has passed password, raises LoginError if not.""" """Checks to see if a user with the given id exists. Will check case
user_info = yield self.store.get_user_by_id(user_id=user_id) insensitively, but will throw if there are multiple inexact matches.
if not user_info:
Returns:
tuple: A 2-tuple of `(canonical_user_id, password_hash)`
"""
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
if not user_infos:
logger.warn("Attempted to login as %s but they do not exist", user_id) logger.warn("Attempted to login as %s but they do not exist", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info["password_hash"] if len(user_infos) > 1:
if user_id not in user_infos:
logger.warn(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
user_id, user_infos.keys()
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue((user_id, user_infos[user_id]))
else:
defer.returnValue(user_infos.popitem())
def _check_password(self, user_id, password, stored_hash):
"""Checks that user_id has passed password, raises LoginError if not."""
if not bcrypt.checkpw(password, stored_hash): if not bcrypt.checkpw(password, stored_hash):
logger.warn("Failed password login for user %s", user_id) logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)

View file

@ -57,8 +57,8 @@ class RegistrationHandler(BaseHandler):
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_is_valid(user_id)
u = yield self.store.get_user_by_id(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id)
if u: if users:
raise SynapseError( raise SynapseError(
400, 400,
"User ID already taken.", "User ID already taken.",

View file

@ -83,9 +83,10 @@ class LoginRestServlet(ClientV1RestServlet):
if not user_id.startswith('@'): if not user_id.startswith('@'):
user_id = UserID.create( user_id = UserID.create(
user_id, self.hs.hostname).to_string() user_id, self.hs.hostname
).to_string()
token = yield self.handlers.auth_handler.login_with_password( user_id, token = yield self.handlers.auth_handler.login_with_password(
user_id=user_id, user_id=user_id,
password=login_submission["password"]) password=login_submission["password"])

View file

@ -98,6 +98,20 @@ class RegistrationStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
"""
def f(txn):
sql = (
"SELECT name, password_hash FROM users"
" WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
return dict(txn.fetchall())
return self.runInteraction("get_users_by_id_case_insensitive", f)
@defer.inlineCallbacks @defer.inlineCallbacks
def user_set_password_hash(self, user_id, password_hash): def user_set_password_hash(self, user_id, password_hash):
""" """