0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-23 01:14:02 +01:00

Implement password changing (finally) along with a start on making client/server auth more general.

This commit is contained in:
David Baker 2015-03-23 14:20:28 +00:00
parent 72d8406409
commit d98660a60d
7 changed files with 236 additions and 49 deletions

View file

@ -29,6 +29,7 @@ from .typing import TypingNotificationHandler
from .admin import AdminHandler from .admin import AdminHandler
from .appservice import ApplicationServicesHandler from .appservice import ApplicationServicesHandler
from .sync import SyncHandler from .sync import SyncHandler
from .auth import AuthHandler
class Handlers(object): class Handlers(object):
@ -58,3 +59,4 @@ class Handlers(object):
hs, ApplicationServiceApi(hs) hs, ApplicationServiceApi(hs)
) )
self.sync_handler = SyncHandler(hs) self.sync_handler = SyncHandler(hs)
self.auth_handler = AuthHandler(hs)

109
synapse/handlers/auth.py Normal file
View file

@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import LoginError, Codes
import logging
import bcrypt
logger = logging.getLogger(__name__)
class AuthHandler(BaseHandler):
def __init__(self, hs):
super(AuthHandler, self).__init__(hs)
@defer.inlineCallbacks
def check_auth(self, flows, clientdict):
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the login flow.
Args:
flows: list of list of stages
authdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
Returns:
A tuple of authed, dict where authed is true if the client
has successfully completed an auth flow. If it is true, the dict
contains the authenticated credentials of each stage.
If authed is false, the dictionary is the server response to the
login request and should be passed back to the client.
"""
types = {
LoginType.PASSWORD: self.check_password_auth
}
if 'auth' not in clientdict:
defer.returnValue((False, auth_dict_for_flows(flows)))
authdict = clientdict['auth']
# In future: support sessions & retrieve previously succeeded
# login types
creds = {}
# check auth type currently being presented
if 'type' not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
if authdict['type'] not in types:
raise LoginError(400, "", Codes.UNRECOGNIZED)
result = yield types[authdict['type']](authdict)
if result:
creds[authdict['type']] = result
for f in flows:
if len(set(f) - set(creds.keys())) == 0:
logger.info("Auth completed with creds: %r", creds)
defer.returnValue((True, creds))
ret = auth_dict_for_flows(flows)
ret['completed'] = creds.keys()
defer.returnValue((False, ret))
@defer.inlineCallbacks
def check_password_auth(self, authdict):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
user = authdict["user"]
password = authdict["password"]
if not user.startswith('@'):
user = UserID.create(user, self.hs.hostname).to_string()
user_info = yield self.store.get_user_by_id(user_id=user)
if not user_info:
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info[0]["password_hash"]
if bcrypt.checkpw(password, stored_hash):
defer.returnValue(user)
else:
logger.warn("Failed password login for user %s", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
def auth_dict_for_flows(flows):
return {
"flows": {"stages": f for f in flows}
}

View file

@ -69,48 +69,9 @@ class LoginHandler(BaseHandler):
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks @defer.inlineCallbacks
def reset_password(self, user_id, email): def set_password(self, user_id, newpassword, token_id=None):
is_valid = yield self._check_valid_association(user_id, email) password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
logger.info("reset_password user=%s email=%s valid=%s", user_id, email,
is_valid)
if is_valid:
try:
# send an email out
emailutils.send_email(
smtp_server=self.hs.config.email_smtp_server,
from_addr=self.hs.config.email_from_address,
to_addr=email,
subject="Password Reset",
body="TODO."
)
except EmailException as e:
logger.exception(e)
@defer.inlineCallbacks yield self.store.user_set_password_hash(user_id, password_hash)
def _check_valid_association(self, user_id, email): yield self.store.user_delete_access_tokens_apart_from(user_id, token_id)
identity = yield self._query_email(email) yield self.store.flush_user(user_id)
if identity and "mxid" in identity:
if identity["mxid"] == user_id:
defer.returnValue(True)
return
defer.returnValue(False)
@defer.inlineCallbacks
def _query_email(self, email):
http_client = SimpleHttpClient(self.hs)
try:
data = yield http_client.get_json(
# TODO FIXME This should be configurable.
# XXX: ID servers need to use HTTPS
"http://%s%s" % (
"matrix.org:8090", "/_matrix/identity/api/v1/lookup"
),
{
'medium': 'email',
'address': email
}
)
defer.returnValue(data)
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)

View file

@ -15,7 +15,8 @@
from . import ( from . import (
sync, sync,
filter filter,
password
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -32,3 +33,4 @@ class ClientV2AlphaRestResource(JsonResource):
def register_servlets(client_resource, hs): def register_servlets(client_resource, hs):
sync.register_servlets(hs, client_resource) sync.register_servlets(hs, client_resource)
filter.register_servlets(hs, client_resource) filter.register_servlets(hs, client_resource)
password.register_servlets(hs, client_resource)

View file

@ -17,9 +17,11 @@
""" """
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.api.errors import SynapseError
import re import re
import logging import logging
import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,3 +38,13 @@ def client_v2_pattern(path_regex):
SRE_Pattern SRE_Pattern
""" """
return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex) return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)
def parse_json_dict_from_request(request):
try:
content = simplejson.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.")

View file

@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import LoginError, SynapseError, Codes
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_json_dict_from_request
import simplejson as json
import logging
logger = logging.getLogger(__name__)
class PasswordRestServlet(RestServlet):
PATTERN = client_v2_pattern("/account/password")
def __init__(self, hs):
super(PasswordRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_dict_from_request(request)
authed, result = yield self.auth_handler.check_auth([
[LoginType.PASSWORD]
], body)
if not authed:
defer.returnValue((401, result))
auth_user = None
if LoginType.PASSWORD in result:
# if using password, they should also be logged in
auth_user, client = yield self.auth.get_user_by_req(request)
if auth_user.to_string() != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN)
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
user_id = auth_user.to_string()
if 'new_password' not in body:
raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = body['new_password']
self.login_handler.set_password(
user_id, new_password, client.token_id
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
PasswordRestServlet(hs).register(http_server)

View file

@ -95,11 +95,36 @@ class RegistrationStore(SQLBaseStore):
"get_user_by_id", self.cursor_to_dict, query, user_id "get_user_by_id", self.cursor_to_dict, query, user_id
) )
def user_set_password_hash(self, user_id, password_hash):
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
pointless. Use flush_user separately.
"""
return self._simple_update_one('users', {
'name': user_id
}, {
'password_hash': password_hash
})
def user_delete_access_tokens_apart_from(self, user_id, token_id):
return self._execute(
"delete_access_tokens_apart_from", None,
"DELETE FROM access_tokens WHERE user_id = ? AND id != ?",
user_id, token_id
)
@defer.inlineCallbacks
def flush_user(self, user_id):
rows = yield self._execute(
'user_delete_access_tokens_apart_from', None,
"SELECT token FROM access_tokens WHERE user_id = ?",
user_id
)
for r in rows:
self.get_user_by_token.invalidate(r)
@cached() @cached()
# TODO(paul): Currently there's no code to invalidate this cache. That
# means if/when we ever add internal ways to invalidate access tokens or
# change whether a user is a server admin, those will need to invoke
# store.get_user_by_token.invalidate(token)
def get_user_by_token(self, token): def get_user_by_token(self, token):
"""Get a user from the given access token. """Get a user from the given access token.