0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 04:13:52 +01:00

Merge pull request #221 from matrix-org/auth

Simplify LoginHander and AuthHandler
This commit is contained in:
Daniel Wagner-Hall 2015-08-14 17:02:22 +01:00
commit 30883d8409
9 changed files with 102 additions and 145 deletions

View file

@ -22,7 +22,6 @@ from .room import (
from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler
from .login import LoginHandler
from .profile import ProfileHandler
from .presence import PresenceHandler
from .directory import DirectoryHandler
@ -54,7 +53,6 @@ class Handlers(object):
self.profile_handler = ProfileHandler(hs)
self.presence_handler = PresenceHandler(hs)
self.room_list_handler = RoomListHandler(hs)
self.login_handler = LoginHandler(hs)
self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs)

View file

@ -47,17 +47,24 @@ class AuthHandler(BaseHandler):
self.sessions = {}
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip=None):
def check_auth(self, flows, clientdict, clientip):
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the login flow.
As a side effect, this function fills in the 'creds' key on the user's
session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool).
Args:
flows: list of list of stages
authdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
flows (list): A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client.
Returns:
A tuple of authed, dict, dict where authed is true if the client
A tuple of (authed, dict, dict) where authed is true if the client
has successfully completed an auth flow. If it is true, the first
dict contains the authenticated credentials of each stage.
@ -75,7 +82,7 @@ class AuthHandler(BaseHandler):
del clientdict['auth']
if 'session' in authdict:
sid = authdict['session']
sess = self._get_session_info(sid)
session = self._get_session_info(sid)
if len(clientdict) > 0:
# This was designed to allow the client to omit the parameters
@ -87,20 +94,19 @@ class AuthHandler(BaseHandler):
# on a home server.
# Revisit: Assumimg the REST APIs do sensible validation, the data
# isn't arbintrary.
sess['clientdict'] = clientdict
self._save_session(sess)
pass
elif 'clientdict' in sess:
clientdict = sess['clientdict']
session['clientdict'] = clientdict
self._save_session(session)
elif 'clientdict' in session:
clientdict = session['clientdict']
if not authdict:
defer.returnValue(
(False, self._auth_dict_for_flows(flows, sess), clientdict)
(False, self._auth_dict_for_flows(flows, session), clientdict)
)
if 'creds' not in sess:
sess['creds'] = {}
creds = sess['creds']
if 'creds' not in session:
session['creds'] = {}
creds = session['creds']
# check auth type currently being presented
if 'type' in authdict:
@ -109,15 +115,15 @@ class AuthHandler(BaseHandler):
result = yield self.checkers[authdict['type']](authdict, clientip)
if result:
creds[authdict['type']] = result
self._save_session(sess)
self._save_session(session)
for f in flows:
if len(set(f) - set(creds.keys())) == 0:
logger.info("Auth completed with creds: %r", creds)
self._remove_session(sess)
self._remove_session(session)
defer.returnValue((True, creds, clientdict))
ret = self._auth_dict_for_flows(flows, sess)
ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys()
defer.returnValue((False, ret, clientdict))
@ -151,22 +157,13 @@ class AuthHandler(BaseHandler):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
user = authdict["user"]
user_id = authdict["user"]
password = authdict["password"]
if not user.startswith('@'):
user = UserID.create(user, self.hs.hostname).to_string()
if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string()
user_info = yield self.store.get_user_by_id(user_id=user)
if not user_info:
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
stored_hash = user_info["password_hash"]
if bcrypt.checkpw(password, stored_hash):
defer.returnValue(user)
else:
logger.warn("Failed password login for user %s", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
self._check_password(user_id, password)
defer.returnValue(user_id)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
@ -270,6 +267,58 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
@defer.inlineCallbacks
def login_with_password(self, user_id, password):
"""
Authenticates the user with their username and password.
Used only by the v1 login API.
Args:
user_id (str): User ID
password (str): Password
Returns:
The access token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
self._check_password(user_id, password)
reg_handler = self.hs.get_handlers().registration_handler
access_token = reg_handler.generate_token(user_id)
logger.info("Adding token %s for user %s", access_token, user_id)
yield self.store.add_access_token_to_user(user_id, access_token)
defer.returnValue(access_token)
def _check_password(self, user_id, password):
"""Checks that user_id has passed password, raises LoginError if not."""
user_info = yield self.store.get_user_by_id(user_id=user_id)
if not user_info:
logger.warn("Attempted to login as %s but they do not exist", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info["password_hash"]
if not bcrypt.checkpw(password, stored_hash):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def set_password(self, user_id, newpassword):
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens(user_id)
yield self.hs.get_pusherpool().remove_pushers_by_user(user_id)
yield self.store.flush_user(user_id)
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
yield self.store.user_add_threepid(
user_id, medium, address, validated_at,
self.hs.get_clock().time_msec()
)
def _save_session(self, session):
# TODO: Persistent storage
logger.debug("Saving session %s", session)

View file

@ -1,83 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import LoginError, Codes
import bcrypt
import logging
logger = logging.getLogger(__name__)
class LoginHandler(BaseHandler):
def __init__(self, hs):
super(LoginHandler, self).__init__(hs)
self.hs = hs
@defer.inlineCallbacks
def login(self, user, password):
"""Login as the specified user with the specified password.
Args:
user (str): The user ID.
password (str): The password.
Returns:
The newly allocated access token.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
# TODO do this better, it can't go in __init__ else it cyclic loops
if not hasattr(self, "reg_handler"):
self.reg_handler = self.hs.get_handlers().registration_handler
# pull out the hash for this user if they exist
user_info = yield self.store.get_user_by_id(user_id=user)
if not user_info:
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info["password_hash"]
if bcrypt.checkpw(password, stored_hash):
# generate an access token and store it.
token = self.reg_handler._generate_token(user)
logger.info("Adding token %s for user %s", token, user)
yield self.store.add_access_token_to_user(user, token)
defer.returnValue(token)
else:
logger.warn("Failed password login for user %s", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, token_id=None):
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens_apart_from(user_id, token_id)
yield self.hs.get_pusherpool().remove_pushers_by_user_access_token(
user_id, token_id
)
yield self.store.flush_user(user_id)
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
yield self.store.user_add_threepid(
user_id, medium, address, validated_at,
self.hs.get_clock().time_msec()
)

View file

@ -91,7 +91,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
token = self._generate_token(user_id)
token = self.generate_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
@ -111,7 +111,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id)
token = self.generate_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
@ -161,7 +161,7 @@ class RegistrationHandler(BaseHandler):
400, "Invalid user localpart for this application service.",
errcode=Codes.EXCLUSIVE
)
token = self._generate_token(user_id)
token = self.generate_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
@ -208,7 +208,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id)
token = self.generate_token(user_id)
try:
yield self.store.register(
user_id=user_id,
@ -273,7 +273,7 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE
)
def _generate_token(self, user_id):
def generate_token(self, user_id):
# urlsafe variant uses _ and - so use . as the separator and replace
# all =s with .s so http clients don't quote =s when it is used as
# query params.

View file

@ -94,17 +94,14 @@ class PusherPool:
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def remove_pushers_by_user_access_token(self, user_id, not_access_token_id):
def remove_pushers_by_user(self, user_id):
all = yield self.store.get_all_pushers()
logger.info(
"Removing all pushers for user %s except access token %s",
user_id, not_access_token_id
"Removing all pushers for user %s",
user_id,
)
for p in all:
if (
p['user_name'] == user_id and
p['access_token'] != not_access_token_id
):
if p['user_name'] == user_id:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']

View file

@ -78,9 +78,8 @@ class LoginRestServlet(ClientV1RestServlet):
login_submission["user"] = UserID.create(
login_submission["user"], self.hs.hostname).to_string()
handler = self.handlers.login_handler
token = yield handler.login(
user=login_submission["user"],
token = yield self.handlers.auth_handler.login_with_password(
user_id=login_submission["user"],
password=login_submission["password"])
result = {

View file

@ -36,7 +36,6 @@ class PasswordRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks
def on_POST(self, request):
@ -47,7 +46,7 @@ class PasswordRestServlet(RestServlet):
authed, result, params = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY]
], body)
], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
@ -79,7 +78,7 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password']
yield self.login_handler.set_password(
yield self.auth_handler.set_password(
user_id, new_password, None
)
@ -95,7 +94,6 @@ class ThreepidRestServlet(RestServlet):
def __init__(self, hs):
super(ThreepidRestServlet, self).__init__()
self.hs = hs
self.login_handler = hs.get_handlers().login_handler
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@ -135,7 +133,7 @@ class ThreepidRestServlet(RestServlet):
logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
raise SynapseError(500, "Invalid response from ID Server")
yield self.login_handler.add_threepid(
yield self.auth_handler.add_threepid(
auth_user.to_string(),
threepid['medium'],
threepid['address'],

View file

@ -50,7 +50,6 @@ class RegisterRestServlet(RestServlet):
self.auth_handler = hs.get_handlers().auth_handler
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks
def on_POST(self, request):
@ -143,7 +142,7 @@ class RegisterRestServlet(RestServlet):
if reqd not in threepid:
logger.info("Can't add incomplete 3pid")
else:
yield self.login_handler.add_threepid(
yield self.auth_handler.add_threepid(
user_id,
threepid['medium'],
threepid['address'],

View file

@ -112,16 +112,16 @@ class RegistrationStore(SQLBaseStore):
})
@defer.inlineCallbacks
def user_delete_access_tokens_apart_from(self, user_id, token_id):
def user_delete_access_tokens(self, user_id):
yield self.runInteraction(
"user_delete_access_tokens_apart_from",
self._user_delete_access_tokens_apart_from, user_id, token_id
"user_delete_access_tokens",
self._user_delete_access_tokens, user_id
)
def _user_delete_access_tokens_apart_from(self, txn, user_id, token_id):
def _user_delete_access_tokens(self, txn, user_id):
txn.execute(
"DELETE FROM access_tokens WHERE user_id = ? AND id != ?",
(user_id, token_id)
"DELETE FROM access_tokens WHERE user_id = ?",
(user_id, )
)
@defer.inlineCallbacks