mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-16 02:53:51 +01:00
Minor review fixes
This commit is contained in:
parent
dd2eb49385
commit
2b779af10f
2 changed files with 14 additions and 17 deletions
|
@ -298,11 +298,11 @@ class AuthHandler(BaseHandler):
|
||||||
defer.returnValue((user_id, access_token, refresh_token))
|
defer.returnValue((user_id, access_token, refresh_token))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def login_with_user_id(self, user_id):
|
def get_login_tuple_for_user_id(self, user_id):
|
||||||
"""
|
"""
|
||||||
Authenticates the user with the given user ID,
|
Gets login tuple for the user with the given user ID.
|
||||||
it is intended that the authentication of the user has
|
The user is assumed to have been authenticated by some other
|
||||||
already been verified by other mechanism (e.g. CAS)
|
machanism (e.g. CAS)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): User ID
|
user_id (str): User ID
|
||||||
|
|
|
@ -146,7 +146,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||||
)
|
)
|
||||||
user_id, access_token, refresh_token = (
|
user_id, access_token, refresh_token = (
|
||||||
yield auth_handler.login_with_user_id(user_id)
|
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": user_id, # may have changed
|
||||||
|
@ -179,7 +179,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||||
if user_exists:
|
if user_exists:
|
||||||
user_id, access_token, refresh_token = (
|
user_id, access_token, refresh_token = (
|
||||||
yield auth_handler.login_with_user_id(user_id)
|
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
||||||
)
|
)
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": user_id, # may have changed
|
||||||
|
@ -304,7 +304,6 @@ class CasRedirectServlet(ClientV1RestServlet):
|
||||||
})
|
})
|
||||||
request.redirect("%s?%s" % (self.cas_server_url, serviceParam))
|
request.redirect("%s?%s" % (self.cas_server_url, serviceParam))
|
||||||
request.finish()
|
request.finish()
|
||||||
defer.returnValue(None)
|
|
||||||
|
|
||||||
|
|
||||||
class CasTicketServlet(ClientV1RestServlet):
|
class CasTicketServlet(ClientV1RestServlet):
|
||||||
|
@ -318,21 +317,19 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
clientRedirectUrl = request.args["redirectUrl"][0]
|
client_redirect_url = request.args["redirectUrl"][0]
|
||||||
# TODO: get this from the homeserver rather than creating a new one for
|
http_client = self.hs.get_simple_http_client()
|
||||||
# each request
|
|
||||||
http_client = SimpleHttpClient(self.hs)
|
|
||||||
uri = self.cas_server_url + "/proxyValidate"
|
uri = self.cas_server_url + "/proxyValidate"
|
||||||
args = {
|
args = {
|
||||||
"ticket": request.args["ticket"],
|
"ticket": request.args["ticket"],
|
||||||
"service": self.cas_service_url
|
"service": self.cas_service_url
|
||||||
}
|
}
|
||||||
body = yield http_client.get_raw(uri, args)
|
body = yield http_client.get_raw(uri, args)
|
||||||
result = yield self.handle_cas_response(request, body, clientRedirectUrl)
|
result = yield self.handle_cas_response(request, body, client_redirect_url)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def handle_cas_response(self, request, cas_response_body, clientRedirectUrl):
|
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
|
||||||
user, attributes = self.parse_cas_response(cas_response_body)
|
user, attributes = self.parse_cas_response(cas_response_body)
|
||||||
|
|
||||||
for required_attribute, required_value in self.cas_required_attributes.items():
|
for required_attribute, required_value in self.cas_required_attributes.items():
|
||||||
|
@ -351,15 +348,15 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||||
auth_handler = self.handlers.auth_handler
|
auth_handler = self.handlers.auth_handler
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||||
if not user_exists:
|
if not user_exists:
|
||||||
user_id, ignored = (
|
user_id, _ = (
|
||||||
yield self.handlers.registration_handler.register(localpart=user)
|
yield self.handlers.registration_handler.register(localpart=user)
|
||||||
)
|
)
|
||||||
|
|
||||||
login_token = auth_handler.generate_short_term_login_token(user_id)
|
login_token = auth_handler.generate_short_term_login_token(user_id)
|
||||||
redirectUrl = self.add_login_token_to_redirect_url(clientRedirectUrl, login_token)
|
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
|
||||||
request.redirect(redirectUrl)
|
login_token)
|
||||||
|
request.redirect(redirect_url)
|
||||||
request.finish()
|
request.finish()
|
||||||
defer.returnValue(None)
|
|
||||||
|
|
||||||
def add_login_token_to_redirect_url(self, url, token):
|
def add_login_token_to_redirect_url(self, url, token):
|
||||||
url_parts = list(urlparse.urlparse(url))
|
url_parts = list(urlparse.urlparse(url))
|
||||||
|
|
Loading…
Reference in a new issue