forked from MirrorHub/synapse
Code cleanups and simplifications.
Also: share the saml client between redirect and response handlers.
This commit is contained in:
parent
69a43d9974
commit
426049247b
6 changed files with 53 additions and 50 deletions
|
@ -57,7 +57,6 @@ class LoginType(object):
|
||||||
EMAIL_IDENTITY = u"m.login.email.identity"
|
EMAIL_IDENTITY = u"m.login.email.identity"
|
||||||
MSISDN = u"m.login.msisdn"
|
MSISDN = u"m.login.msisdn"
|
||||||
RECAPTCHA = u"m.login.recaptcha"
|
RECAPTCHA = u"m.login.recaptcha"
|
||||||
SSO = u"m.login.sso"
|
|
||||||
TERMS = u"m.login.terms"
|
TERMS = u"m.login.terms"
|
||||||
DUMMY = u"m.login.dummy"
|
DUMMY = u"m.login.dummy"
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from synapse.python_dependencies import DependencyException, check_requirements
|
||||||
|
|
||||||
from ._base import Config, ConfigError
|
from ._base import Config, ConfigError
|
||||||
|
|
||||||
|
@ -25,6 +26,11 @@ class SAML2Config(Config):
|
||||||
if not saml2_config or not saml2_config.get("enabled", True):
|
if not saml2_config or not saml2_config.get("enabled", True):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
check_requirements('saml2')
|
||||||
|
except DependencyException as e:
|
||||||
|
raise ConfigError(e.message)
|
||||||
|
|
||||||
self.saml2_enabled = True
|
self.saml2_enabled = True
|
||||||
|
|
||||||
import saml2.config
|
import saml2.config
|
||||||
|
@ -75,7 +81,6 @@ class SAML2Config(Config):
|
||||||
# override them.
|
# override them.
|
||||||
#
|
#
|
||||||
#saml2_config:
|
#saml2_config:
|
||||||
# enabled: true
|
|
||||||
# sp_config:
|
# sp_config:
|
||||||
# # point this to the IdP's metadata. You can use either a local file or
|
# # point this to the IdP's metadata. You can use either a local file or
|
||||||
# # (preferably) a URL.
|
# # (preferably) a URL.
|
||||||
|
|
|
@ -767,9 +767,6 @@ class AuthHandler(BaseHandler):
|
||||||
if canonical_user_id:
|
if canonical_user_id:
|
||||||
defer.returnValue((canonical_user_id, None))
|
defer.returnValue((canonical_user_id, None))
|
||||||
|
|
||||||
if login_type == LoginType.SSO:
|
|
||||||
known_login_type = True
|
|
||||||
|
|
||||||
if not known_login_type:
|
if not known_login_type:
|
||||||
raise SynapseError(400, "Unknown login type %s" % login_type)
|
raise SynapseError(400, "Unknown login type %s" % login_type)
|
||||||
|
|
||||||
|
|
|
@ -34,10 +34,6 @@ from synapse.rest.well_known import WellKnownBuilder
|
||||||
from synapse.types import UserID, map_username_to_mxid_localpart
|
from synapse.types import UserID, map_username_to_mxid_localpart
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
|
||||||
import saml2
|
|
||||||
from saml2.client import Saml2Client
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -378,28 +374,49 @@ class LoginRestServlet(RestServlet):
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
|
||||||
class CasRedirectServlet(RestServlet):
|
class BaseSsoRedirectServlet(RestServlet):
|
||||||
|
"""Common base class for /login/sso/redirect impls"""
|
||||||
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
|
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
|
||||||
|
|
||||||
|
def on_GET(self, request):
|
||||||
|
args = request.args
|
||||||
|
if b"redirectUrl" not in args:
|
||||||
|
return 400, "Redirect URL not specified for SSO auth"
|
||||||
|
client_redirect_url = args[b"redirectUrl"][0]
|
||||||
|
sso_url = self.get_sso_url(client_redirect_url)
|
||||||
|
request.redirect(sso_url)
|
||||||
|
finish_request(request)
|
||||||
|
|
||||||
|
def get_sso_url(self, client_redirect_url):
|
||||||
|
"""Get the URL to redirect to, to perform SSO auth
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_redirect_url (bytes): the URL that we should redirect the
|
||||||
|
client to when everything is done
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: URL to redirect to
|
||||||
|
"""
|
||||||
|
# to be implemented by subclasses
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class CasRedirectServlet(RestServlet):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(CasRedirectServlet, self).__init__()
|
super(CasRedirectServlet, self).__init__()
|
||||||
self.cas_server_url = hs.config.cas_server_url.encode('ascii')
|
self.cas_server_url = hs.config.cas_server_url.encode('ascii')
|
||||||
self.cas_service_url = hs.config.cas_service_url.encode('ascii')
|
self.cas_service_url = hs.config.cas_service_url.encode('ascii')
|
||||||
|
|
||||||
def on_GET(self, request):
|
def get_sso_url(self, client_redirect_url):
|
||||||
args = request.args
|
|
||||||
if b"redirectUrl" not in args:
|
|
||||||
return (400, "Redirect URL not specified for CAS auth")
|
|
||||||
client_redirect_url_param = urllib.parse.urlencode({
|
client_redirect_url_param = urllib.parse.urlencode({
|
||||||
b"redirectUrl": args[b"redirectUrl"][0]
|
b"redirectUrl": client_redirect_url
|
||||||
}).encode('ascii')
|
}).encode('ascii')
|
||||||
hs_redirect_url = (self.cas_service_url +
|
hs_redirect_url = (self.cas_service_url +
|
||||||
b"/_matrix/client/r0/login/cas/ticket")
|
b"/_matrix/client/r0/login/cas/ticket")
|
||||||
service_param = urllib.parse.urlencode({
|
service_param = urllib.parse.urlencode({
|
||||||
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
|
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
|
||||||
}).encode('ascii')
|
}).encode('ascii')
|
||||||
request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param))
|
return b"%s/login?%s" % (self.cas_server_url, service_param)
|
||||||
finish_request(request)
|
|
||||||
|
|
||||||
|
|
||||||
class CasTicketServlet(RestServlet):
|
class CasTicketServlet(RestServlet):
|
||||||
|
@ -482,41 +499,23 @@ class CasTicketServlet(RestServlet):
|
||||||
return user, attributes
|
return user, attributes
|
||||||
|
|
||||||
|
|
||||||
class SSORedirectServlet(RestServlet):
|
class SAMLRedirectServlet(BaseSsoRedirectServlet):
|
||||||
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
|
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(SSORedirectServlet, self).__init__()
|
self._saml_client = hs.get_saml_client()
|
||||||
self.saml2_sp_config = hs.config.saml2_sp_config
|
|
||||||
|
|
||||||
def on_GET(self, request):
|
def get_sso_url(self, client_redirect_url):
|
||||||
args = request.args
|
reqid, info = self._saml_client.prepare_for_authenticate(
|
||||||
|
relay_state=client_redirect_url,
|
||||||
|
)
|
||||||
|
|
||||||
saml_client = Saml2Client(self.saml2_sp_config)
|
|
||||||
reqid, info = saml_client.prepare_for_authenticate()
|
|
||||||
|
|
||||||
redirect_url = None
|
|
||||||
|
|
||||||
# Select the IdP URL to send the AuthN request to
|
|
||||||
for key, value in info['headers']:
|
for key, value in info['headers']:
|
||||||
if key is 'Location':
|
if key == 'Location':
|
||||||
redirect_url = value
|
return value
|
||||||
|
|
||||||
if redirect_url is None:
|
# this shouldn't happen!
|
||||||
raise LoginError(401, "Unsuccessful SSO SAML2 redirect url response",
|
raise Exception("prepare_for_authenticate didn't return a Location header")
|
||||||
errcode=Codes.UNAUTHORIZED)
|
|
||||||
|
|
||||||
relay_state = "/_matrix/client/r0/login"
|
|
||||||
if b"redirectUrl" in args:
|
|
||||||
relay_state = args[b"redirectUrl"][0]
|
|
||||||
|
|
||||||
url_parts = list(urllib.parse.urlparse(redirect_url))
|
|
||||||
query = dict(urllib.parse.parse_qsl(url_parts[4]))
|
|
||||||
query.update({"RelayState": relay_state})
|
|
||||||
url_parts[4] = urllib.parse.urlencode(query)
|
|
||||||
|
|
||||||
request.redirect(urllib.parse.urlunparse(url_parts))
|
|
||||||
finish_request(request)
|
|
||||||
|
|
||||||
|
|
||||||
class SSOAuthHandler(object):
|
class SSOAuthHandler(object):
|
||||||
|
@ -594,5 +593,5 @@ def register_servlets(hs, http_server):
|
||||||
if hs.config.cas_enabled:
|
if hs.config.cas_enabled:
|
||||||
CasRedirectServlet(hs).register(http_server)
|
CasRedirectServlet(hs).register(http_server)
|
||||||
CasTicketServlet(hs).register(http_server)
|
CasTicketServlet(hs).register(http_server)
|
||||||
if hs.config.saml2_enabled:
|
elif hs.config.saml2_enabled:
|
||||||
SSORedirectServlet(hs).register(http_server)
|
SAMLRedirectServlet(hs).register(http_server)
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import saml2
|
import saml2
|
||||||
from saml2.client import Saml2Client
|
|
||||||
|
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
from twisted.web.server import NOT_DONE_YET
|
from twisted.web.server import NOT_DONE_YET
|
||||||
|
@ -36,8 +35,7 @@ class SAML2ResponseResource(Resource):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
Resource.__init__(self)
|
Resource.__init__(self)
|
||||||
|
self._saml_client = hs.get_saml_client()
|
||||||
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
|
||||||
self._sso_auth_handler = SSOAuthHandler(hs)
|
self._sso_auth_handler = SSOAuthHandler(hs)
|
||||||
|
|
||||||
def render_POST(self, request):
|
def render_POST(self, request):
|
||||||
|
|
|
@ -189,6 +189,7 @@ class HomeServer(object):
|
||||||
'registration_handler',
|
'registration_handler',
|
||||||
'account_validity_handler',
|
'account_validity_handler',
|
||||||
'event_client_serializer',
|
'event_client_serializer',
|
||||||
|
'saml_client',
|
||||||
]
|
]
|
||||||
|
|
||||||
REQUIRED_ON_MASTER_STARTUP = [
|
REQUIRED_ON_MASTER_STARTUP = [
|
||||||
|
@ -522,6 +523,10 @@ class HomeServer(object):
|
||||||
def build_event_client_serializer(self):
|
def build_event_client_serializer(self):
|
||||||
return EventClientSerializer(self)
|
return EventClientSerializer(self)
|
||||||
|
|
||||||
|
def build_saml_client(self):
|
||||||
|
from saml2.client import Saml2Client
|
||||||
|
return Saml2Client(self.config.saml2_sp_config)
|
||||||
|
|
||||||
def remove_pusher(self, app_id, push_key, user_id):
|
def remove_pusher(self, app_id, push_key, user_id):
|
||||||
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue