forked from MirrorHub/synapse
Support multiple required attributes in CAS response, and in a nicer config format too
This commit is contained in:
parent
76421c496d
commit
01a5f1991c
2 changed files with 10 additions and 22 deletions
|
@ -27,28 +27,17 @@ class CasConfig(Config):
|
|||
if cas_config:
|
||||
self.cas_enabled = True
|
||||
self.cas_server_url = cas_config["server_url"]
|
||||
|
||||
if "required_attribute" in cas_config:
|
||||
self.cas_required_attribute = cas_config["required_attribute"]
|
||||
else:
|
||||
self.cas_required_attribute = None
|
||||
|
||||
if "required_attribute_value" in cas_config:
|
||||
self.cas_required_attribute_value = cas_config["required_attribute_value"]
|
||||
else:
|
||||
self.cas_required_attribute_value = None
|
||||
|
||||
self.cas_required_attributes = cas_config.get("required_attributes", None)
|
||||
else:
|
||||
self.cas_enabled = False
|
||||
self.cas_server_url = None
|
||||
self.cas_required_attribute = None
|
||||
self.cas_required_attribute_value = None
|
||||
self.cas_required_attributes = {}
|
||||
|
||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||
return """
|
||||
# Enable CAS for registration and login.
|
||||
#cas_config:
|
||||
# server_url: "https://cas-server.com"
|
||||
# #required_attribute: something
|
||||
# #required_attribute_value: true
|
||||
# #required_attributes:
|
||||
# # name: value
|
||||
"""
|
||||
|
|
|
@ -46,8 +46,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
self.saml2_enabled = hs.config.saml2_enabled
|
||||
self.cas_enabled = hs.config.cas_enabled
|
||||
self.cas_server_url = hs.config.cas_server_url
|
||||
self.cas_required_attribute = hs.config.cas_required_attribute
|
||||
self.cas_required_attribute_value = hs.config.cas_required_attribute_value
|
||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||
self.servername = hs.config.server_name
|
||||
|
||||
def on_GET(self, request):
|
||||
|
@ -128,16 +127,16 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
def do_cas_login(self, cas_response_body):
|
||||
(user, attributes) = self.parse_cas_response(cas_response_body)
|
||||
|
||||
if self.cas_required_attribute is not None:
|
||||
for required_attribute in self.cas_required_attributes:
|
||||
# If required attribute was not in CAS Response - Forbidden
|
||||
if self.cas_required_attribute not in attributes:
|
||||
if required_attribute not in attributes:
|
||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
# Also need to check value
|
||||
if self.cas_required_attribute_value is not None:
|
||||
actualValue = attributes[self.cas_required_attribute]
|
||||
if self.cas_required_attributes[required_attribute] is not None:
|
||||
actualValue = attributes[required_attribute]
|
||||
# If required attribute value does not match expected - Forbidden
|
||||
if self.cas_required_attribute_value != actualValue:
|
||||
if self.cas_required_attributes[required_attribute] != actualValue:
|
||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
|
|
Loading…
Reference in a new issue