Code cleanups and simplifications.

Also: share the saml client between redirect and response handlers.
This commit is contained in:
Richard van der Hoff 2019-06-11 00:03:57 +01:00
parent 69a43d9974
commit 426049247b
6 changed files with 53 additions and 50 deletions

View file

@ -57,7 +57,6 @@ class LoginType(object):
EMAIL_IDENTITY = u"m.login.email.identity" EMAIL_IDENTITY = u"m.login.email.identity"
MSISDN = u"m.login.msisdn" MSISDN = u"m.login.msisdn"
RECAPTCHA = u"m.login.recaptcha" RECAPTCHA = u"m.login.recaptcha"
SSO = u"m.login.sso"
TERMS = u"m.login.terms" TERMS = u"m.login.terms"
DUMMY = u"m.login.dummy" DUMMY = u"m.login.dummy"

View file

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.python_dependencies import DependencyException, check_requirements
from ._base import Config, ConfigError from ._base import Config, ConfigError
@ -25,6 +26,11 @@ class SAML2Config(Config):
if not saml2_config or not saml2_config.get("enabled", True): if not saml2_config or not saml2_config.get("enabled", True):
return return
try:
check_requirements('saml2')
except DependencyException as e:
raise ConfigError(e.message)
self.saml2_enabled = True self.saml2_enabled = True
import saml2.config import saml2.config
@ -75,7 +81,6 @@ class SAML2Config(Config):
# override them. # override them.
# #
#saml2_config: #saml2_config:
# enabled: true
# sp_config: # sp_config:
# # point this to the IdP's metadata. You can use either a local file or # # point this to the IdP's metadata. You can use either a local file or
# # (preferably) a URL. # # (preferably) a URL.

View file

@ -767,9 +767,6 @@ class AuthHandler(BaseHandler):
if canonical_user_id: if canonical_user_id:
defer.returnValue((canonical_user_id, None)) defer.returnValue((canonical_user_id, None))
if login_type == LoginType.SSO:
known_login_type = True
if not known_login_type: if not known_login_type:
raise SynapseError(400, "Unknown login type %s" % login_type) raise SynapseError(400, "Unknown login type %s" % login_type)

View file

@ -34,10 +34,6 @@ from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
import saml2
from saml2.client import Saml2Client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -378,28 +374,49 @@ class LoginRestServlet(RestServlet):
defer.returnValue(result) defer.returnValue(result)
class CasRedirectServlet(RestServlet): class BaseSsoRedirectServlet(RestServlet):
"""Common base class for /login/sso/redirect impls"""
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def on_GET(self, request):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
client_redirect_url = args[b"redirectUrl"][0]
sso_url = self.get_sso_url(client_redirect_url)
request.redirect(sso_url)
finish_request(request)
def get_sso_url(self, client_redirect_url):
"""Get the URL to redirect to, to perform SSO auth
Args:
client_redirect_url (bytes): the URL that we should redirect the
client to when everything is done
Returns:
bytes: URL to redirect to
"""
# to be implemented by subclasses
raise NotImplementedError()
class CasRedirectServlet(RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(CasRedirectServlet, self).__init__() super(CasRedirectServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url.encode('ascii') self.cas_server_url = hs.config.cas_server_url.encode('ascii')
self.cas_service_url = hs.config.cas_service_url.encode('ascii') self.cas_service_url = hs.config.cas_service_url.encode('ascii')
def on_GET(self, request): def get_sso_url(self, client_redirect_url):
args = request.args
if b"redirectUrl" not in args:
return (400, "Redirect URL not specified for CAS auth")
client_redirect_url_param = urllib.parse.urlencode({ client_redirect_url_param = urllib.parse.urlencode({
b"redirectUrl": args[b"redirectUrl"][0] b"redirectUrl": client_redirect_url
}).encode('ascii') }).encode('ascii')
hs_redirect_url = (self.cas_service_url + hs_redirect_url = (self.cas_service_url +
b"/_matrix/client/r0/login/cas/ticket") b"/_matrix/client/r0/login/cas/ticket")
service_param = urllib.parse.urlencode({ service_param = urllib.parse.urlencode({
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param) b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
}).encode('ascii') }).encode('ascii')
request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param)) return b"%s/login?%s" % (self.cas_server_url, service_param)
finish_request(request)
class CasTicketServlet(RestServlet): class CasTicketServlet(RestServlet):
@ -482,41 +499,23 @@ class CasTicketServlet(RestServlet):
return user, attributes return user, attributes
class SSORedirectServlet(RestServlet): class SAMLRedirectServlet(BaseSsoRedirectServlet):
PATTERNS = client_patterns("/login/sso/redirect", v1=True) PATTERNS = client_patterns("/login/sso/redirect", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(SSORedirectServlet, self).__init__() self._saml_client = hs.get_saml_client()
self.saml2_sp_config = hs.config.saml2_sp_config
def on_GET(self, request): def get_sso_url(self, client_redirect_url):
args = request.args reqid, info = self._saml_client.prepare_for_authenticate(
relay_state=client_redirect_url,
)
saml_client = Saml2Client(self.saml2_sp_config)
reqid, info = saml_client.prepare_for_authenticate()
redirect_url = None
# Select the IdP URL to send the AuthN request to
for key, value in info['headers']: for key, value in info['headers']:
if key is 'Location': if key == 'Location':
redirect_url = value return value
if redirect_url is None: # this shouldn't happen!
raise LoginError(401, "Unsuccessful SSO SAML2 redirect url response", raise Exception("prepare_for_authenticate didn't return a Location header")
errcode=Codes.UNAUTHORIZED)
relay_state = "/_matrix/client/r0/login"
if b"redirectUrl" in args:
relay_state = args[b"redirectUrl"][0]
url_parts = list(urllib.parse.urlparse(redirect_url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"RelayState": relay_state})
url_parts[4] = urllib.parse.urlencode(query)
request.redirect(urllib.parse.urlunparse(url_parts))
finish_request(request)
class SSOAuthHandler(object): class SSOAuthHandler(object):
@ -594,5 +593,5 @@ def register_servlets(hs, http_server):
if hs.config.cas_enabled: if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server) CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server)
if hs.config.saml2_enabled: elif hs.config.saml2_enabled:
SSORedirectServlet(hs).register(http_server) SAMLRedirectServlet(hs).register(http_server)

View file

@ -16,7 +16,6 @@
import logging import logging
import saml2 import saml2
from saml2.client import Saml2Client
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
@ -36,8 +35,7 @@ class SAML2ResponseResource(Resource):
def __init__(self, hs): def __init__(self, hs):
Resource.__init__(self) Resource.__init__(self)
self._saml_client = hs.get_saml_client()
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._sso_auth_handler = SSOAuthHandler(hs) self._sso_auth_handler = SSOAuthHandler(hs)
def render_POST(self, request): def render_POST(self, request):

View file

@ -189,6 +189,7 @@ class HomeServer(object):
'registration_handler', 'registration_handler',
'account_validity_handler', 'account_validity_handler',
'event_client_serializer', 'event_client_serializer',
'saml_client',
] ]
REQUIRED_ON_MASTER_STARTUP = [ REQUIRED_ON_MASTER_STARTUP = [
@ -522,6 +523,10 @@ class HomeServer(object):
def build_event_client_serializer(self): def build_event_client_serializer(self):
return EventClientSerializer(self) return EventClientSerializer(self)
def build_saml_client(self):
from saml2.client import Saml2Client
return Saml2Client(self.config.saml2_sp_config)
def remove_pusher(self, app_id, push_key, user_id): def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)