mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 09:54:00 +01:00
move LDAP authentication to AuthenticationHandler
This commit is contained in:
parent
3d95405e5f
commit
7b9319b1c8
2 changed files with 48 additions and 61 deletions
|
@ -30,6 +30,8 @@ import simplejson
|
||||||
|
|
||||||
import synapse.util.stringutils as stringutils
|
import synapse.util.stringutils as stringutils
|
||||||
|
|
||||||
|
import ldap
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -49,6 +51,15 @@ class AuthHandler(BaseHandler):
|
||||||
self.sessions = {}
|
self.sessions = {}
|
||||||
self.INVALID_TOKEN_HTTP_STATUS = 401
|
self.INVALID_TOKEN_HTTP_STATUS = 401
|
||||||
|
|
||||||
|
self.ldap_enabled = hs.config.ldap_enabled
|
||||||
|
self.ldap_server = hs.config.ldap_server
|
||||||
|
self.ldap_port = hs.config.ldap_port
|
||||||
|
self.ldap_search_base = hs.config.ldap_search_base
|
||||||
|
self.ldap_search_property = hs.config.ldap_search_property
|
||||||
|
self.ldap_email_property = hs.config.ldap_email_property
|
||||||
|
self.ldap_full_name_property = hs.config.ldap_full_name_property
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_auth(self, flows, clientdict, clientip):
|
def check_auth(self, flows, clientdict, clientip):
|
||||||
"""
|
"""
|
||||||
|
@ -215,8 +226,8 @@ class AuthHandler(BaseHandler):
|
||||||
if not user_id.startswith('@'):
|
if not user_id.startswith('@'):
|
||||||
user_id = UserID.create(user_id, self.hs.hostname).to_string()
|
user_id = UserID.create(user_id, self.hs.hostname).to_string()
|
||||||
|
|
||||||
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
self._check_password(user_id, password)
|
||||||
self._check_password(user_id, password, password_hash)
|
|
||||||
defer.returnValue(user_id)
|
defer.returnValue(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -340,8 +351,8 @@ class AuthHandler(BaseHandler):
|
||||||
StoreError if there was a problem storing the token.
|
StoreError if there was a problem storing the token.
|
||||||
LoginError if there was an authentication problem.
|
LoginError if there was an authentication problem.
|
||||||
"""
|
"""
|
||||||
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
|
||||||
self._check_password(user_id, password, password_hash)
|
self._check_password(user_id, password)
|
||||||
|
|
||||||
logger.info("Logging in user %s", user_id)
|
logger.info("Logging in user %s", user_id)
|
||||||
access_token = yield self.issue_access_token(user_id)
|
access_token = yield self.issue_access_token(user_id)
|
||||||
|
@ -407,12 +418,43 @@ class AuthHandler(BaseHandler):
|
||||||
else:
|
else:
|
||||||
defer.returnValue(user_infos.popitem())
|
defer.returnValue(user_infos.popitem())
|
||||||
|
|
||||||
def _check_password(self, user_id, password, stored_hash):
|
def _check_password(self, user_id, password):
|
||||||
"""Checks that user_id has passed password, raises LoginError if not."""
|
"""Checks that user_id has passed password, raises LoginError if not."""
|
||||||
if not self.validate_hash(password, stored_hash):
|
|
||||||
|
if not (self._check_ldap_password(user_id, password) or self._check_local_password(user_id, password)):
|
||||||
logger.warn("Failed password login for user %s", user_id)
|
logger.warn("Failed password login for user %s", user_id)
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
def _check_local_password(self, user_id, password):
|
||||||
|
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
||||||
|
return not self.validate_hash(password, password_hash)
|
||||||
|
|
||||||
|
def _check_ldap_password(self, user_id, password):
|
||||||
|
if not self.ldap_enabled:
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info("Authenticating %s with LDAP" % user_id)
|
||||||
|
try:
|
||||||
|
l = ldap.initialize("%s:%s" % (ldap_server, ldap_port))
|
||||||
|
if self.ldap_tls:
|
||||||
|
logger.debug("Initiating TLS")
|
||||||
|
self._connection.start_tls_s()
|
||||||
|
|
||||||
|
dn = "%s=%s, %s" % (ldap_search_property, user_id.localpart, ldap_search_base)
|
||||||
|
logger.debug("DN for LDAP authentication: %s" % dn)
|
||||||
|
|
||||||
|
l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
|
||||||
|
|
||||||
|
if not self.does_user_exist(user_id):
|
||||||
|
user_id, access_token = (
|
||||||
|
yield self.handlers.registration_handler.register(localpart=user_id.localpart)
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
except ldap.LDAPError, e:
|
||||||
|
logger.info(e)
|
||||||
|
return False
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def issue_access_token(self, user_id):
|
def issue_access_token(self, user_id):
|
||||||
access_token = self.generate_access_token(user_id)
|
access_token = self.generate_access_token(user_id)
|
||||||
|
|
|
@ -36,8 +36,6 @@ import xml.etree.ElementTree as ET
|
||||||
import jwt
|
import jwt
|
||||||
from jwt.exceptions import InvalidTokenError
|
from jwt.exceptions import InvalidTokenError
|
||||||
|
|
||||||
import ldap
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -49,7 +47,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
CAS_TYPE = "m.login.cas"
|
CAS_TYPE = "m.login.cas"
|
||||||
TOKEN_TYPE = "m.login.token"
|
TOKEN_TYPE = "m.login.token"
|
||||||
JWT_TYPE = "m.login.jwt"
|
JWT_TYPE = "m.login.jwt"
|
||||||
LDAP_TYPE = "m.login.ldap"
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(LoginRestServlet, self).__init__(hs)
|
super(LoginRestServlet, self).__init__(hs)
|
||||||
|
@ -59,13 +56,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
self.jwt_enabled = hs.config.jwt_enabled
|
self.jwt_enabled = hs.config.jwt_enabled
|
||||||
self.jwt_secret = hs.config.jwt_secret
|
self.jwt_secret = hs.config.jwt_secret
|
||||||
self.jwt_algorithm = hs.config.jwt_algorithm
|
self.jwt_algorithm = hs.config.jwt_algorithm
|
||||||
self.ldap_enabled = hs.config.ldap_enabled
|
|
||||||
self.ldap_server = hs.config.ldap_server
|
|
||||||
self.ldap_port = hs.config.ldap_port
|
|
||||||
self.ldap_search_base = hs.config.ldap_search_base
|
|
||||||
self.ldap_search_property = hs.config.ldap_search_property
|
|
||||||
self.ldap_email_property = hs.config.ldap_email_property
|
|
||||||
self.ldap_full_name_property = hs.config.ldap_full_name_property
|
|
||||||
self.cas_enabled = hs.config.cas_enabled
|
self.cas_enabled = hs.config.cas_enabled
|
||||||
self.cas_server_url = hs.config.cas_server_url
|
self.cas_server_url = hs.config.cas_server_url
|
||||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||||
|
@ -74,8 +64,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
flows = []
|
flows = []
|
||||||
if self.ldap_enabled:
|
|
||||||
flows.append({"type": LoginRestServlet.LDAP_TYPE})
|
|
||||||
if self.jwt_enabled:
|
if self.jwt_enabled:
|
||||||
flows.append({"type": LoginRestServlet.JWT_TYPE})
|
flows.append({"type": LoginRestServlet.JWT_TYPE})
|
||||||
if self.saml2_enabled:
|
if self.saml2_enabled:
|
||||||
|
@ -176,49 +164,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def do_ldap_login(self, login_submission):
|
|
||||||
if 'medium' in login_submission and 'address' in login_submission:
|
|
||||||
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
|
||||||
login_submission['medium'], login_submission['address']
|
|
||||||
)
|
|
||||||
if not user_id:
|
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
|
||||||
else:
|
|
||||||
user_id = login_submission['user']
|
|
||||||
|
|
||||||
if not user_id.startswith('@'):
|
|
||||||
user_id = UserID.create(
|
|
||||||
user_id, self.hs.hostname
|
|
||||||
).to_string()
|
|
||||||
|
|
||||||
# FIXME check against LDAP Server!!
|
|
||||||
|
|
||||||
auth_handler = self.handlers.auth_handler
|
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
|
||||||
if user_exists:
|
|
||||||
user_id, access_token, refresh_token = (
|
|
||||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
|
||||||
)
|
|
||||||
result = {
|
|
||||||
"user_id": user_id, # may have changed
|
|
||||||
"access_token": access_token,
|
|
||||||
"refresh_token": refresh_token,
|
|
||||||
"home_server": self.hs.hostname,
|
|
||||||
}
|
|
||||||
|
|
||||||
else:
|
|
||||||
user_id, access_token = (
|
|
||||||
yield self.handlers.registration_handler.register(localpart=user_id.localpart)
|
|
||||||
)
|
|
||||||
result = {
|
|
||||||
"user_id": user_id, # may have changed
|
|
||||||
"access_token": access_token,
|
|
||||||
"home_server": self.hs.hostname,
|
|
||||||
}
|
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_token_login(self, login_submission):
|
def do_token_login(self, login_submission):
|
||||||
|
|
Loading…
Reference in a new issue