0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-21 00:21:56 +01:00

Allow hs to do CAS login completely and issue the client with a login token that can be redeemed for the usual successful login response

This commit is contained in:
Steven Hammerton 2015-11-05 14:01:12 +00:00
parent 45f1827fb7
commit 414a4a71b4
3 changed files with 218 additions and 5 deletions

View file

@ -41,7 +41,7 @@ class CasConfig(Config):
#cas_config: #cas_config:
# enabled: true # enabled: true
# server_url: "https://cas-server.com" # server_url: "https://cas-server.com"
# ticket_redirect_url: "https://homesever.domain.com:8448" # service_url: "https://homesever.domain.com:8448"
# #required_attributes: # #required_attributes:
# # name: value # # name: value
""" """

View file

@ -18,7 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import LoginError, Codes from synapse.api.errors import AuthError, LoginError, Codes
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@ -46,6 +46,13 @@ class AuthHandler(BaseHandler):
} }
self.bcrypt_rounds = hs.config.bcrypt_rounds self.bcrypt_rounds = hs.config.bcrypt_rounds
self.sessions = {} self.sessions = {}
self.INVALID_TOKEN_HTTP_STATUS = 401
self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ",
"type = ",
"time < ",
"user_id = ",
])
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(self, flows, clientdict, clientip):
@ -297,10 +304,11 @@ class AuthHandler(BaseHandler):
defer.returnValue((user_id, access_token, refresh_token)) defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks @defer.inlineCallbacks
def login_with_cas_user_id(self, user_id): def login_with_user_id(self, user_id):
""" """
Authenticates the user with the given user ID, Authenticates the user with the given user ID,
intended to have been captured from a CAS response it is intended that the authentication of the user has
already been verified by other mechanism (e.g. CAS)
Args: Args:
user_id (str): User ID user_id (str): User ID
@ -393,6 +401,17 @@ class AuthHandler(BaseHandler):
)) ))
return m.serialize() return m.serialize()
def generate_short_term_login_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + (2 * 60 * 1000)
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token):
return self._validate_macaroon_and_get_user_id(login_token, "login", True)
def _generate_base_macaroon(self, user_id): def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
@ -402,6 +421,57 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon return macaroon
def _validate_macaroon_and_get_user_id(self, macaroon_str,
macaroon_type, validate_expiry):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
user_id = self._get_user_from_macaroon(macaroon)
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = " + macaroon_type)
v.satisfy_exact("user_id = " + user_id)
if validate_expiry:
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
v = pymacaroons.Verifier()
v.satisfy_general(self._verify_recognizes_caveats)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
return user_id
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(
self.INVALID_TOKEN_HTTP_STATUS, "Invalid token",
errcode=Codes.UNKNOWN_TOKEN
)
def _get_user_from_macaroon(self, macaroon):
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
raise AuthError(
self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
errcode=Codes.UNKNOWN_TOKEN
)
def _verify_expiry(self, caveat):
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix):])
now = self.hs.get_clock().time_msec()
return now < expiry
def _verify_recognizes_caveats(self, caveat):
first_space = caveat.find(" ")
if first_space < 0:
return False
second_space = caveat.find(" ", first_space + 1)
if second_space < 0:
return False
return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES
@defer.inlineCallbacks @defer.inlineCallbacks
def set_password(self, user_id, newpassword): def set_password(self, user_id, newpassword):
password_hash = self.hash(newpassword) password_hash = self.hash(newpassword)

View file

@ -22,6 +22,7 @@ from base import ClientV1RestServlet, client_path_pattern
import simplejson as json import simplejson as json
import urllib import urllib
import urlparse
import logging import logging
from saml2 import BINDING_HTTP_POST from saml2 import BINDING_HTTP_POST
@ -39,6 +40,7 @@ class LoginRestServlet(ClientV1RestServlet):
PASS_TYPE = "m.login.password" PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2" SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas" CAS_TYPE = "m.login.cas"
TOKEN_TYPE = "m.login.token"
def __init__(self, hs): def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs) super(LoginRestServlet, self).__init__(hs)
@ -58,6 +60,7 @@ class LoginRestServlet(ClientV1RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE}) flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.password_enabled: if self.password_enabled:
flows.append({"type": LoginRestServlet.PASS_TYPE}) flows.append({"type": LoginRestServlet.PASS_TYPE})
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
return (200, {"flows": flows}) return (200, {"flows": flows})
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
@ -83,6 +86,7 @@ class LoginRestServlet(ClientV1RestServlet):
"uri": "%s%s" % (self.idp_redirect_url, relay_state) "uri": "%s%s" % (self.idp_redirect_url, relay_state)
} }
defer.returnValue((200, result)) defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
elif self.cas_enabled and (login_submission["type"] == elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE): LoginRestServlet.CAS_TYPE):
# TODO: get this from the homeserver rather than creating a new one for # TODO: get this from the homeserver rather than creating a new one for
@ -96,6 +100,9 @@ class LoginRestServlet(ClientV1RestServlet):
body = yield http_client.get_raw(uri, args) body = yield http_client.get_raw(uri, args)
result = yield self.do_cas_login(body) result = yield self.do_cas_login(body)
defer.returnValue(result) defer.returnValue(result)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
else: else:
raise SynapseError(400, "Bad login type.") raise SynapseError(400, "Bad login type.")
except KeyError: except KeyError:
@ -131,6 +138,26 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
@defer.inlineCallbacks
def do_token_login(self, login_submission):
token = login_submission['token']
auth_handler = self.handlers.auth_handler
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
user_id, access_token, refresh_token = (
yield auth_handler.login_with_user_id(user_id)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
@defer.inlineCallbacks @defer.inlineCallbacks
def do_cas_login(self, cas_response_body): def do_cas_login(self, cas_response_body):
user, attributes = self.parse_cas_response(cas_response_body) user, attributes = self.parse_cas_response(cas_response_body)
@ -152,7 +179,7 @@ class LoginRestServlet(ClientV1RestServlet):
user_exists = yield auth_handler.does_user_exist(user_id) user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists: if user_exists:
user_id, access_token, refresh_token = ( user_id, access_token, refresh_token = (
yield auth_handler.login_with_cas_user_id(user_id) yield auth_handler.login_with_user_id(user_id)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
@ -173,6 +200,7 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
def parse_cas_response(self, cas_response_body): def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body) root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"): if not root.tag.endswith("serviceResponse"):
@ -243,6 +271,7 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue((200, {"status": "not_authenticated"})) defer.returnValue((200, {"status": "not_authenticated"}))
# TODO Delete this after all CAS clients switch to token login instead
class CasRestServlet(ClientV1RestServlet): class CasRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas") PATTERN = client_path_pattern("/login/cas")
@ -254,6 +283,118 @@ class CasRestServlet(ClientV1RestServlet):
return (200, {"serverUrl": self.cas_server_url}) return (200, {"serverUrl": self.cas_server_url})
class CasRedirectServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas/redirect")
def __init__(self, hs):
super(CasRedirectServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
def on_GET(self, request):
args = request.args
if "redirectUrl" not in args:
return (400, "Redirect URL not specified for CAS auth")
clientRedirectUrlParam = urllib.urlencode({
"redirectUrl": args["redirectUrl"][0]
})
hsRedirectUrl = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket"
serviceParam = urllib.urlencode({
"service": "%s?%s" % (hsRedirectUrl, clientRedirectUrlParam)
})
request.redirect("%s?%s" % (self.cas_server_url, serviceParam))
request.finish()
defer.returnValue(None)
class CasTicketServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas/ticket")
def __init__(self, hs):
super(CasTicketServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
@defer.inlineCallbacks
def on_GET(self, request):
clientRedirectUrl = request.args["redirectUrl"][0]
# TODO: get this from the homeserver rather than creating a new one for
# each request
http_client = SimpleHttpClient(self.hs)
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": request.args["ticket"],
"service": self.cas_service_url
}
body = yield http_client.get_raw(uri, args)
result = yield self.handle_cas_response(request, body, clientRedirectUrl)
defer.returnValue(result)
@defer.inlineCallbacks
def handle_cas_response(self, request, cas_response_body, clientRedirectUrl):
user, attributes = self.parse_cas_response(cas_response_body)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if not user_exists:
user_id, ignored = (
yield self.handlers.registration_handler.register(localpart=user)
)
login_token = auth_handler.generate_short_term_login_token(user_id)
redirectUrl = self.add_login_token_to_redirect_url(clientRedirectUrl, login_token)
request.redirect(redirectUrl)
request.finish()
defer.returnValue(None)
def add_login_token_to_redirect_url(self, url, token):
url_parts = list(urlparse.urlparse(url))
query = dict(urlparse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
url_parts[4] = urllib.urlencode(query)
return urlparse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in attribute tags
# to the full URL of the namespace.
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
# We don't care about namespace here and it will always be encased in
# curly braces, so we remove them.
if "}" in attribute.tag:
attributes[attribute.tag.split("}")[1]] = attribute.text
else:
attributes[attribute.tag] = attribute.text
if user is None or attributes is None:
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
return (user, attributes)
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
@ -269,5 +410,7 @@ def register_servlets(hs, http_server):
if hs.config.saml2_enabled: if hs.config.saml2_enabled:
SAML2RestServlet(hs).register(http_server) SAML2RestServlet(hs).register(http_server)
if hs.config.cas_enabled: if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
CasRestServlet(hs).register(http_server) CasRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server) # TODO PasswordResetRestServlet(hs).register(http_server)