Merge pull request #349 from stevenhammerton/sh-cas-auth-via-homeserver

SH CAS auth via homeserver
This commit is contained in:
Erik Johnston 2015-11-17 14:36:15 +00:00
commit e503848990
5 changed files with 199 additions and 39 deletions

View file

@ -587,7 +587,10 @@ class Auth(object):
def _get_user_from_macaroon(self, macaroon_str): def _get_user_from_macaroon(self, macaroon_str):
try: try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self._validate_macaroon(macaroon) self.validate_macaroon(
macaroon, "access",
[lambda c: c.startswith("time < ")]
)
user_prefix = "user_id = " user_prefix = "user_id = "
user = None user = None
@ -635,26 +638,25 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
) )
def _validate_macaroon(self, macaroon): def validate_macaroon(self, macaroon, type_string, additional_validation_functions):
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1") v.satisfy_exact("gen = 1")
v.satisfy_exact("type = access") v.satisfy_exact("type = " + type_string)
v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(self._verify_expiry)
v.satisfy_exact("guest = true") v.satisfy_exact("guest = true")
for validation_function in additional_validation_functions:
v.satisfy_general(validation_function)
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
v.satisfy_general(self._verify_recognizes_caveats) v.satisfy_general(self._verify_recognizes_caveats)
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
def _verify_expiry(self, caveat): def verify_expiry(self, caveat):
prefix = "time < " prefix = "time < "
if not caveat.startswith(prefix): if not caveat.startswith(prefix):
return False return False
# TODO(daniel): Enable expiry check when clients actually know how to
# refresh tokens. (And remember to enable the tests)
return True
expiry = int(caveat[len(prefix):]) expiry = int(caveat[len(prefix):])
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
return now < expiry return now < expiry

View file

@ -27,10 +27,12 @@ class CasConfig(Config):
if cas_config: if cas_config:
self.cas_enabled = cas_config.get("enabled", True) self.cas_enabled = cas_config.get("enabled", True)
self.cas_server_url = cas_config["server_url"] self.cas_server_url = cas_config["server_url"]
self.cas_service_url = cas_config["service_url"]
self.cas_required_attributes = cas_config.get("required_attributes", {}) self.cas_required_attributes = cas_config.get("required_attributes", {})
else: else:
self.cas_enabled = False self.cas_enabled = False
self.cas_server_url = None self.cas_server_url = None
self.cas_service_url = None
self.cas_required_attributes = {} self.cas_required_attributes = {}
def default_config(self, config_dir_path, server_name, **kwargs): def default_config(self, config_dir_path, server_name, **kwargs):
@ -39,6 +41,7 @@ class CasConfig(Config):
#cas_config: #cas_config:
# enabled: true # enabled: true
# server_url: "https://cas-server.com" # server_url: "https://cas-server.com"
# service_url: "https://homesever.domain.com:8448"
# #required_attributes: # #required_attributes:
# # name: value # # name: value
""" """

View file

@ -18,7 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import LoginError, Codes from synapse.api.errors import AuthError, LoginError, Codes
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@ -46,6 +46,7 @@ class AuthHandler(BaseHandler):
} }
self.bcrypt_rounds = hs.config.bcrypt_rounds self.bcrypt_rounds = hs.config.bcrypt_rounds
self.sessions = {} self.sessions = {}
self.INVALID_TOKEN_HTTP_STATUS = 401
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(self, flows, clientdict, clientip):
@ -297,10 +298,11 @@ class AuthHandler(BaseHandler):
defer.returnValue((user_id, access_token, refresh_token)) defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks @defer.inlineCallbacks
def login_with_cas_user_id(self, user_id): def get_login_tuple_for_user_id(self, user_id):
""" """
Authenticates the user with the given user ID, Gets login tuple for the user with the given user ID.
intended to have been captured from a CAS response The user is assumed to have been authenticated by some other
machanism (e.g. CAS)
Args: Args:
user_id (str): User ID user_id (str): User ID
@ -393,6 +395,23 @@ class AuthHandler(BaseHandler):
)) ))
return m.serialize() return m.serialize()
def generate_short_term_login_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + (2 * 60 * 1000)
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token):
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)
auth_api = self.hs.get_auth()
auth_api.validate_macaroon(macaroon, "login", [auth_api.verify_expiry])
return self._get_user_from_macaroon(macaroon)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
def _generate_base_macaroon(self, user_id): def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
@ -402,6 +421,16 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon return macaroon
def _get_user_from_macaroon(self, macaroon):
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
raise AuthError(
self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
errcode=Codes.UNKNOWN_TOKEN
)
@defer.inlineCallbacks @defer.inlineCallbacks
def set_password(self, user_id, newpassword): def set_password(self, user_id, newpassword):
password_hash = self.hash(newpassword) password_hash = self.hash(newpassword)

View file

@ -22,6 +22,7 @@ from base import ClientV1RestServlet, client_path_pattern
import simplejson as json import simplejson as json
import urllib import urllib
import urlparse
import logging import logging
from saml2 import BINDING_HTTP_POST from saml2 import BINDING_HTTP_POST
@ -39,6 +40,7 @@ class LoginRestServlet(ClientV1RestServlet):
PASS_TYPE = "m.login.password" PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2" SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas" CAS_TYPE = "m.login.cas"
TOKEN_TYPE = "m.login.token"
def __init__(self, hs): def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs) super(LoginRestServlet, self).__init__(hs)
@ -58,6 +60,7 @@ class LoginRestServlet(ClientV1RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE}) flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.password_enabled: if self.password_enabled:
flows.append({"type": LoginRestServlet.PASS_TYPE}) flows.append({"type": LoginRestServlet.PASS_TYPE})
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
return (200, {"flows": flows}) return (200, {"flows": flows})
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
@ -83,6 +86,7 @@ class LoginRestServlet(ClientV1RestServlet):
"uri": "%s%s" % (self.idp_redirect_url, relay_state) "uri": "%s%s" % (self.idp_redirect_url, relay_state)
} }
defer.returnValue((200, result)) defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
elif self.cas_enabled and (login_submission["type"] == elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE): LoginRestServlet.CAS_TYPE):
# TODO: get this from the homeserver rather than creating a new one for # TODO: get this from the homeserver rather than creating a new one for
@ -96,6 +100,9 @@ class LoginRestServlet(ClientV1RestServlet):
body = yield http_client.get_raw(uri, args) body = yield http_client.get_raw(uri, args)
result = yield self.do_cas_login(body) result = yield self.do_cas_login(body)
defer.returnValue(result) defer.returnValue(result)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
else: else:
raise SynapseError(400, "Bad login type.") raise SynapseError(400, "Bad login type.")
except KeyError: except KeyError:
@ -131,6 +138,26 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
@defer.inlineCallbacks
def do_token_login(self, login_submission):
token = login_submission['token']
auth_handler = self.handlers.auth_handler
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
@defer.inlineCallbacks @defer.inlineCallbacks
def do_cas_login(self, cas_response_body): def do_cas_login(self, cas_response_body):
user, attributes = self.parse_cas_response(cas_response_body) user, attributes = self.parse_cas_response(cas_response_body)
@ -152,7 +179,7 @@ class LoginRestServlet(ClientV1RestServlet):
user_exists = yield auth_handler.does_user_exist(user_id) user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists: if user_exists:
user_id, access_token, refresh_token = ( user_id, access_token, refresh_token = (
yield auth_handler.login_with_cas_user_id(user_id) yield auth_handler.get_login_tuple_for_user_id(user_id)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
@ -173,6 +200,7 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
def parse_cas_response(self, cas_response_body): def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body) root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"): if not root.tag.endswith("serviceResponse"):
@ -243,6 +271,7 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue((200, {"status": "not_authenticated"})) defer.returnValue((200, {"status": "not_authenticated"}))
# TODO Delete this after all CAS clients switch to token login instead
class CasRestServlet(ClientV1RestServlet): class CasRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas") PATTERN = client_path_pattern("/login/cas")
@ -254,6 +283,115 @@ class CasRestServlet(ClientV1RestServlet):
return (200, {"serverUrl": self.cas_server_url}) return (200, {"serverUrl": self.cas_server_url})
class CasRedirectServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas/redirect")
def __init__(self, hs):
super(CasRedirectServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
def on_GET(self, request):
args = request.args
if "redirectUrl" not in args:
return (400, "Redirect URL not specified for CAS auth")
client_redirect_url_param = urllib.urlencode({
"redirectUrl": args["redirectUrl"][0]
})
hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket"
service_param = urllib.urlencode({
"service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
})
request.redirect("%s?%s" % (self.cas_server_url, service_param))
request.finish()
class CasTicketServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas/ticket")
def __init__(self, hs):
super(CasTicketServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
@defer.inlineCallbacks
def on_GET(self, request):
client_redirect_url = request.args["redirectUrl"][0]
http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": request.args["ticket"],
"service": self.cas_service_url
}
body = yield http_client.get_raw(uri, args)
result = yield self.handle_cas_response(request, body, client_redirect_url)
defer.returnValue(result)
@defer.inlineCallbacks
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if not user_exists:
user_id, _ = (
yield self.handlers.registration_handler.register(localpart=user)
)
login_token = auth_handler.generate_short_term_login_token(user_id)
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token)
request.redirect(redirect_url)
request.finish()
def add_login_token_to_redirect_url(self, url, token):
url_parts = list(urlparse.urlparse(url))
query = dict(urlparse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
url_parts[4] = urllib.urlencode(query)
return urlparse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in attribute tags
# to the full URL of the namespace.
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
# We don't care about namespace here and it will always be encased in
# curly braces, so we remove them.
if "}" in attribute.tag:
attributes[attribute.tag.split("}")[1]] = attribute.text
else:
attributes[attribute.tag] = attribute.text
if user is None or attributes is None:
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
return (user, attributes)
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
@ -269,5 +407,7 @@ def register_servlets(hs, http_server):
if hs.config.saml2_enabled: if hs.config.saml2_enabled:
SAML2RestServlet(hs).register(http_server) SAML2RestServlet(hs).register(http_server)
if hs.config.cas_enabled: if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
CasRestServlet(hs).register(http_server) CasRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server) # TODO PasswordResetRestServlet(hs).register(http_server)

View file

@ -17,12 +17,11 @@ var submitPassword = function(user, pwd) {
}).error(errorFunc); }).error(errorFunc);
}; };
var submitCas = function(ticket, service) { var submitToken = function(loginToken) {
console.log("Logging in with cas..."); console.log("Logging in with login token...");
var data = { var data = {
type: "m.login.cas", type: "m.login.token",
ticket: ticket, token: loginToken
service: service,
}; };
$.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) {
show_login(); show_login();
@ -41,23 +40,10 @@ var errorFunc = function(err) {
} }
}; };
var getCasURL = function(cb) {
$.get(matrixLogin.endpoint + "/cas", function(response) {
var cas_url = response.serverUrl;
cb(cas_url);
}).error(errorFunc);
};
var gotoCas = function() { var gotoCas = function() {
getCasURL(function(cas_url) { var this_page = window.location.origin + window.location.pathname;
var this_page = window.location.origin + window.location.pathname; var redirect_url = matrixLogin.endpoint + "/cas/redirect?redirectUrl=" + encodeURIComponent(this_page);
window.location.replace(redirect_url);
var redirect_url = cas_url + "/login?service=" + encodeURIComponent(this_page);
window.location.replace(redirect_url);
});
} }
var setFeedbackString = function(text) { var setFeedbackString = function(text) {
@ -111,7 +97,7 @@ var fetch_info = function(cb) {
matrixLogin.onLoad = function() { matrixLogin.onLoad = function() {
fetch_info(function() { fetch_info(function() {
if (!try_cas()) { if (!try_token()) {
show_login(); show_login();
} }
}); });
@ -148,20 +134,20 @@ var parseQsFromUrl = function(query) {
return result; return result;
}; };
var try_cas = function() { var try_token = function() {
var pos = window.location.href.indexOf("?"); var pos = window.location.href.indexOf("?");
if (pos == -1) { if (pos == -1) {
return false; return false;
} }
var qs = parseQsFromUrl(window.location.href.substr(pos+1)); var qs = parseQsFromUrl(window.location.href.substr(pos+1));
var ticket = qs.ticket; var loginToken = qs.loginToken;
if (!ticket) { if (!loginToken) {
return false; return false;
} }
submitCas(ticket, location.origin); submitToken(loginToken);
return true; return true;
}; };