mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 01:53:58 +01:00
Merge erikj/user_dedup to develop
This commit is contained in:
parent
a2355fae7e
commit
d3c0e48859
4 changed files with 50 additions and 12 deletions
|
@ -163,7 +163,8 @@ class AuthHandler(BaseHandler):
|
||||||
if not user_id.startswith('@'):
|
if not user_id.startswith('@'):
|
||||||
user_id = UserID.create(user_id, self.hs.hostname).to_string()
|
user_id = UserID.create(user_id, self.hs.hostname).to_string()
|
||||||
|
|
||||||
yield self._check_password(user_id, password)
|
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
||||||
|
self._check_password(user_id, password, password_hash)
|
||||||
defer.returnValue(user_id)
|
defer.returnValue(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -280,27 +281,49 @@ class AuthHandler(BaseHandler):
|
||||||
password (str): Password
|
password (str): Password
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of:
|
A tuple of:
|
||||||
|
The user's ID.
|
||||||
The access token for the user's session.
|
The access token for the user's session.
|
||||||
The refresh token for the user's session.
|
The refresh token for the user's session.
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if there was a problem storing the token.
|
StoreError if there was a problem storing the token.
|
||||||
LoginError if there was an authentication problem.
|
LoginError if there was an authentication problem.
|
||||||
"""
|
"""
|
||||||
yield self._check_password(user_id, password)
|
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
||||||
|
self._check_password(user_id, password, password_hash)
|
||||||
|
|
||||||
logger.info("Logging in user %s", user_id)
|
logger.info("Logging in user %s", user_id)
|
||||||
access_token = yield self.issue_access_token(user_id)
|
access_token = yield self.issue_access_token(user_id)
|
||||||
refresh_token = yield self.issue_refresh_token(user_id)
|
refresh_token = yield self.issue_refresh_token(user_id)
|
||||||
defer.returnValue((access_token, refresh_token))
|
defer.returnValue((user_id, access_token, refresh_token))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_password(self, user_id, password):
|
def _find_user_id_and_pwd_hash(self, user_id):
|
||||||
"""Checks that user_id has passed password, raises LoginError if not."""
|
"""Checks to see if a user with the given id exists. Will check case
|
||||||
user_info = yield self.store.get_user_by_id(user_id=user_id)
|
insensitively, but will throw if there are multiple inexact matches.
|
||||||
if not user_info:
|
|
||||||
|
Returns:
|
||||||
|
tuple: A 2-tuple of `(canonical_user_id, password_hash)`
|
||||||
|
"""
|
||||||
|
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
|
||||||
|
if not user_infos:
|
||||||
logger.warn("Attempted to login as %s but they do not exist", user_id)
|
logger.warn("Attempted to login as %s but they do not exist", user_id)
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
stored_hash = user_info["password_hash"]
|
if len(user_infos) > 1:
|
||||||
|
if user_id not in user_infos:
|
||||||
|
logger.warn(
|
||||||
|
"Attempted to login as %s but it matches more than one user "
|
||||||
|
"inexactly: %r",
|
||||||
|
user_id, user_infos.keys()
|
||||||
|
)
|
||||||
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
defer.returnValue((user_id, user_infos[user_id]))
|
||||||
|
else:
|
||||||
|
defer.returnValue(user_infos.popitem())
|
||||||
|
|
||||||
|
def _check_password(self, user_id, password, stored_hash):
|
||||||
|
"""Checks that user_id has passed password, raises LoginError if not."""
|
||||||
if not bcrypt.checkpw(password, stored_hash):
|
if not bcrypt.checkpw(password, stored_hash):
|
||||||
logger.warn("Failed password login for user %s", user_id)
|
logger.warn("Failed password login for user %s", user_id)
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
|
|
|
@ -56,8 +56,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
yield self.check_user_id_is_valid(user_id)
|
yield self.check_user_id_is_valid(user_id)
|
||||||
|
|
||||||
u = yield self.store.get_user_by_id(user_id)
|
users = yield self.store.get_users_by_id_case_insensitive(user_id)
|
||||||
if u:
|
if users:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
"User ID already taken.",
|
"User ID already taken.",
|
||||||
|
|
|
@ -83,10 +83,11 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
if not user_id.startswith('@'):
|
if not user_id.startswith('@'):
|
||||||
user_id = UserID.create(
|
user_id = UserID.create(
|
||||||
user_id, self.hs.hostname).to_string()
|
user_id, self.hs.hostname
|
||||||
|
).to_string()
|
||||||
|
|
||||||
auth_handler = self.handlers.auth_handler
|
auth_handler = self.handlers.auth_handler
|
||||||
access_token, refresh_token = yield auth_handler.login_with_password(
|
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
password=login_submission["password"])
|
password=login_submission["password"])
|
||||||
|
|
||||||
|
|
|
@ -120,6 +120,20 @@ class RegistrationStore(SQLBaseStore):
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_users_by_id_case_insensitive(self, user_id):
|
||||||
|
"""Gets users that match user_id case insensitively.
|
||||||
|
Returns a mapping of user_id -> password_hash.
|
||||||
|
"""
|
||||||
|
def f(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT name, password_hash FROM users"
|
||||||
|
" WHERE lower(name) = lower(?)"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_id,))
|
||||||
|
return dict(txn.fetchall())
|
||||||
|
|
||||||
|
return self.runInteraction("get_users_by_id_case_insensitive", f)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_set_password_hash(self, user_id, password_hash):
|
def user_set_password_hash(self, user_id, password_hash):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in a new issue