forked from MirrorHub/synapse
Share more code between macaroon validation
This commit is contained in:
parent
0b31223c7a
commit
dd2eb49385
2 changed files with 17 additions and 57 deletions
|
@ -587,7 +587,10 @@ class Auth(object):
|
||||||
def _get_user_from_macaroon(self, macaroon_str):
|
def _get_user_from_macaroon(self, macaroon_str):
|
||||||
try:
|
try:
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
||||||
self._validate_macaroon(macaroon)
|
self.validate_macaroon(
|
||||||
|
macaroon, "access",
|
||||||
|
[lambda c: c == "guest = true", lambda c: c.startswith("time < ")]
|
||||||
|
)
|
||||||
|
|
||||||
user_prefix = "user_id = "
|
user_prefix = "user_id = "
|
||||||
user = None
|
user = None
|
||||||
|
@ -635,26 +638,24 @@ class Auth(object):
|
||||||
errcode=Codes.UNKNOWN_TOKEN
|
errcode=Codes.UNKNOWN_TOKEN
|
||||||
)
|
)
|
||||||
|
|
||||||
def _validate_macaroon(self, macaroon):
|
def validate_macaroon(self, macaroon, type_string, additional_validation_functions):
|
||||||
v = pymacaroons.Verifier()
|
v = pymacaroons.Verifier()
|
||||||
v.satisfy_exact("gen = 1")
|
v.satisfy_exact("gen = 1")
|
||||||
v.satisfy_exact("type = access")
|
v.satisfy_exact("type = " + type_string)
|
||||||
v.satisfy_general(lambda c: c.startswith("user_id = "))
|
v.satisfy_general(lambda c: c.startswith("user_id = "))
|
||||||
v.satisfy_general(self._verify_expiry)
|
|
||||||
v.satisfy_exact("guest = true")
|
for validation_function in additional_validation_functions:
|
||||||
|
v.satisfy_general(validation_function)
|
||||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||||
|
|
||||||
v = pymacaroons.Verifier()
|
v = pymacaroons.Verifier()
|
||||||
v.satisfy_general(self._verify_recognizes_caveats)
|
v.satisfy_general(self._verify_recognizes_caveats)
|
||||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||||
|
|
||||||
def _verify_expiry(self, caveat):
|
def verify_expiry(self, caveat):
|
||||||
prefix = "time < "
|
prefix = "time < "
|
||||||
if not caveat.startswith(prefix):
|
if not caveat.startswith(prefix):
|
||||||
return False
|
return False
|
||||||
# TODO(daniel): Enable expiry check when clients actually know how to
|
|
||||||
# refresh tokens. (And remember to enable the tests)
|
|
||||||
return True
|
|
||||||
expiry = int(caveat[len(prefix):])
|
expiry = int(caveat[len(prefix):])
|
||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
return now < expiry
|
return now < expiry
|
||||||
|
|
|
@ -47,12 +47,6 @@ class AuthHandler(BaseHandler):
|
||||||
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
||||||
self.sessions = {}
|
self.sessions = {}
|
||||||
self.INVALID_TOKEN_HTTP_STATUS = 401
|
self.INVALID_TOKEN_HTTP_STATUS = 401
|
||||||
self._KNOWN_CAVEAT_PREFIXES = set([
|
|
||||||
"gen = ",
|
|
||||||
"type = ",
|
|
||||||
"time < ",
|
|
||||||
"user_id = ",
|
|
||||||
])
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_auth(self, flows, clientdict, clientip):
|
def check_auth(self, flows, clientdict, clientip):
|
||||||
|
@ -410,7 +404,13 @@ class AuthHandler(BaseHandler):
|
||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
||||||
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
||||||
return self._validate_macaroon_and_get_user_id(login_token, "login", True)
|
try:
|
||||||
|
macaroon = pymacaroons.Macaroon.deserialize(login_token)
|
||||||
|
auth_api = self.hs.get_auth()
|
||||||
|
auth_api.validate_macaroon(macaroon, "login", [auth_api.verify_expiry])
|
||||||
|
return self._get_user_from_macaroon(macaroon)
|
||||||
|
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
|
||||||
|
raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
|
||||||
|
|
||||||
def _generate_base_macaroon(self, user_id):
|
def _generate_base_macaroon(self, user_id):
|
||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
|
@ -421,30 +421,6 @@ class AuthHandler(BaseHandler):
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
return macaroon
|
return macaroon
|
||||||
|
|
||||||
def _validate_macaroon_and_get_user_id(self, macaroon_str,
|
|
||||||
macaroon_type, validate_expiry):
|
|
||||||
try:
|
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
|
||||||
user_id = self._get_user_from_macaroon(macaroon)
|
|
||||||
v = pymacaroons.Verifier()
|
|
||||||
v.satisfy_exact("gen = 1")
|
|
||||||
v.satisfy_exact("type = " + macaroon_type)
|
|
||||||
v.satisfy_exact("user_id = " + user_id)
|
|
||||||
if validate_expiry:
|
|
||||||
v.satisfy_general(self._verify_expiry)
|
|
||||||
|
|
||||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
|
||||||
|
|
||||||
v = pymacaroons.Verifier()
|
|
||||||
v.satisfy_general(self._verify_recognizes_caveats)
|
|
||||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
|
||||||
return user_id
|
|
||||||
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
|
|
||||||
raise AuthError(
|
|
||||||
self.INVALID_TOKEN_HTTP_STATUS, "Invalid token",
|
|
||||||
errcode=Codes.UNKNOWN_TOKEN
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_user_from_macaroon(self, macaroon):
|
def _get_user_from_macaroon(self, macaroon):
|
||||||
user_prefix = "user_id = "
|
user_prefix = "user_id = "
|
||||||
for caveat in macaroon.caveats:
|
for caveat in macaroon.caveats:
|
||||||
|
@ -455,23 +431,6 @@ class AuthHandler(BaseHandler):
|
||||||
errcode=Codes.UNKNOWN_TOKEN
|
errcode=Codes.UNKNOWN_TOKEN
|
||||||
)
|
)
|
||||||
|
|
||||||
def _verify_expiry(self, caveat):
|
|
||||||
prefix = "time < "
|
|
||||||
if not caveat.startswith(prefix):
|
|
||||||
return False
|
|
||||||
expiry = int(caveat[len(prefix):])
|
|
||||||
now = self.hs.get_clock().time_msec()
|
|
||||||
return now < expiry
|
|
||||||
|
|
||||||
def _verify_recognizes_caveats(self, caveat):
|
|
||||||
first_space = caveat.find(" ")
|
|
||||||
if first_space < 0:
|
|
||||||
return False
|
|
||||||
second_space = caveat.find(" ", first_space + 1)
|
|
||||||
if second_space < 0:
|
|
||||||
return False
|
|
||||||
return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_password(self, user_id, newpassword):
|
def set_password(self, user_id, newpassword):
|
||||||
password_hash = self.hash(newpassword)
|
password_hash = self.hash(newpassword)
|
||||||
|
|
Loading…
Add table
Reference in a new issue