0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-13 16:18:56 +02:00

Provide ability to login using CAS

This commit is contained in:
Steven Hammerton 2015-10-07 14:45:57 +01:00
parent ce19fc0f11
commit c33f5c1a24
4 changed files with 135 additions and 2 deletions

39
synapse/config/cas.py Normal file
View file

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class CasConfig(Config):
"""Cas Configuration
cas_server_url: URL of CAS server
"""
def read_config(self, config):
cas_config = config.get("cas_config", None)
if cas_config:
self.cas_enabled = True
self.cas_server_url = cas_config["server_url"]
else:
self.cas_enabled = False
self.cas_server_url = None
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable CAS for registration and login.
#cas_config:
# server_url: "https://cas-server.com"
"""

View file

@ -26,12 +26,13 @@ from .metrics import MetricsConfig
from .appservice import AppServiceConfig
from .key import KeyConfig
from .saml2 import SAML2Config
from .cas import CasConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig,
AppServiceConfig, KeyConfig, SAML2Config, ):
AppServiceConfig, KeyConfig, SAML2Config, CasConfig):
pass

View file

@ -295,6 +295,37 @@ class AuthHandler(BaseHandler):
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
def login_with_cas_user_id(self, user_id):
"""
Authenticates the user with the given user ID, intended to have been captured from a CAS response
Args:
user_id (str): User ID
Returns:
A tuple of:
The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
def does_user_exist(self, user_id):
try:
yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(True)
except LoginError:
defer.returnValue(False)
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
"""Checks to see if a user with the given id exists. Will check case

View file

@ -15,7 +15,7 @@
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.api.errors import SynapseError, LoginError, Codes
from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern
@ -27,6 +27,9 @@ from saml2 import BINDING_HTTP_POST
from saml2 import config
from saml2.client import Saml2Client
import xml.etree.ElementTree as ET
import requests
logger = logging.getLogger(__name__)
@ -35,16 +38,23 @@ class LoginRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login$")
PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.cas_server_url = hs.config.cas_server_url
self.servername = hs.config.server_name
def on_GET(self, request):
flows = [{"type": LoginRestServlet.PASS_TYPE}]
if self.saml2_enabled:
flows.append({"type": LoginRestServlet.SAML2_TYPE})
if self.cas_enabled:
flows.append({"type": LoginRestServlet.CAS_TYPE})
return (200, {"flows": flows})
def on_OPTIONS(self, request):
@ -67,6 +77,12 @@ class LoginRestServlet(ClientV1RestServlet):
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
}
defer.returnValue((200, result))
elif self.cas_enabled and (login_submission["type"] == LoginRestServlet.CAS_TYPE):
url = "%s/proxyValidate" % (self.cas_server_url)
parameters = {"ticket": login_submission["ticket"], "service": login_submission["service"]}
response = requests.get(url, verify=False, params=parameters)
result = yield self.do_cas_login(response.text)
defer.returnValue(result)
else:
raise SynapseError(400, "Bad login type.")
except KeyError:
@ -100,6 +116,41 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result))
@defer.inlineCallbacks
def do_cas_login(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
user_id = "@%s:%s" % (user, self.servername)
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = yield auth_handler.login_with_cas_user_id(user_id)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
user_id, access_token = yield self.handlers.registration_handler.register(localpart=user)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
class LoginFallbackRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/fallback$")
@ -173,6 +224,15 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue(None)
defer.returnValue((200, {"status": "not_authenticated"}))
class CasRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas")
def __init__(self, hs):
super(CasRestServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
def on_GET(self, request):
return (200, {"serverUrl": self.cas_server_url})
def _parse_json(request):
try:
@ -188,4 +248,6 @@ def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.saml2_enabled:
SAML2RestServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server)