forked from MirrorHub/synapse
Merge pull request #918 from negzi/bugfix_for_token_expiry
Bug fix: expire invalid access tokens
This commit is contained in:
commit
209e04fa11
6 changed files with 42 additions and 9 deletions
|
@ -629,7 +629,10 @@ class Auth(object):
|
||||||
except AuthError:
|
except AuthError:
|
||||||
# TODO(daniel): Remove this fallback when all existing access tokens
|
# TODO(daniel): Remove this fallback when all existing access tokens
|
||||||
# have been re-issued as macaroons.
|
# have been re-issued as macaroons.
|
||||||
|
if self.hs.config.expire_access_token:
|
||||||
|
raise
|
||||||
ret = yield self._look_up_user_by_access_token(token)
|
ret = yield self._look_up_user_by_access_token(token)
|
||||||
|
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -637,12 +637,13 @@ class AuthHandler(BaseHandler):
|
||||||
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
|
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
|
||||||
defer.returnValue(refresh_token)
|
defer.returnValue(refresh_token)
|
||||||
|
|
||||||
def generate_access_token(self, user_id, extra_caveats=None):
|
def generate_access_token(self, user_id, extra_caveats=None,
|
||||||
|
duration_in_ms=(60 * 60 * 1000)):
|
||||||
extra_caveats = extra_caveats or []
|
extra_caveats = extra_caveats or []
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
expiry = now + (60 * 60 * 1000)
|
expiry = now + duration_in_ms
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
for caveat in extra_caveats:
|
for caveat in extra_caveats:
|
||||||
macaroon.add_first_party_caveat(caveat)
|
macaroon.add_first_party_caveat(caveat)
|
||||||
|
|
|
@ -360,7 +360,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_or_create_user(self, localpart, displayname, duration_seconds,
|
def get_or_create_user(self, localpart, displayname, duration_in_ms,
|
||||||
password_hash=None):
|
password_hash=None):
|
||||||
"""Creates a new user if the user does not exist,
|
"""Creates a new user if the user does not exist,
|
||||||
else revokes all previous access tokens and generates a new one.
|
else revokes all previous access tokens and generates a new one.
|
||||||
|
@ -390,8 +390,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
token = self.auth_handler().generate_short_term_login_token(
|
token = self.auth_handler().generate_access_token(
|
||||||
user_id, duration_seconds)
|
user_id, None, duration_in_ms)
|
||||||
|
|
||||||
if need_register:
|
if need_register:
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
|
|
|
@ -429,7 +429,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
||||||
user_id, token = yield handler.get_or_create_user(
|
user_id, token = yield handler.get_or_create_user(
|
||||||
localpart=localpart,
|
localpart=localpart,
|
||||||
displayname=displayname,
|
displayname=displayname,
|
||||||
duration_seconds=duration_seconds,
|
duration_in_ms=(duration_seconds * 1000),
|
||||||
password_hash=password_hash
|
password_hash=password_hash
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -281,7 +281,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
macaroon.add_first_party_caveat("user_id = %s" % (user,))
|
||||||
macaroon.add_first_party_caveat("time < 1") # ms
|
macaroon.add_first_party_caveat("time < -2000") # ms
|
||||||
|
|
||||||
self.hs.clock.now = 5000 # seconds
|
self.hs.clock.now = 5000 # seconds
|
||||||
self.hs.config.expire_access_token = True
|
self.hs.config.expire_access_token = True
|
||||||
|
@ -293,3 +293,32 @@ class AuthTestCase(unittest.TestCase):
|
||||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||||
self.assertEqual(401, cm.exception.code)
|
self.assertEqual(401, cm.exception.code)
|
||||||
self.assertIn("Invalid macaroon", cm.exception.msg)
|
self.assertIn("Invalid macaroon", cm.exception.msg)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_get_user_from_macaroon_with_valid_duration(self):
|
||||||
|
# TODO(danielwh): Remove this mock when we remove the
|
||||||
|
# get_user_by_access_token fallback.
|
||||||
|
self.store.get_user_by_access_token = Mock(
|
||||||
|
return_value={"name": "@baldrick:matrix.org"}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.store.get_user_by_access_token = Mock(
|
||||||
|
return_value={"name": "@baldrick:matrix.org"}
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id = "@baldrick:matrix.org"
|
||||||
|
macaroon = pymacaroons.Macaroon(
|
||||||
|
location=self.hs.config.server_name,
|
||||||
|
identifier="key",
|
||||||
|
key=self.hs.config.macaroon_secret_key)
|
||||||
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
|
macaroon.add_first_party_caveat("type = access")
|
||||||
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
|
macaroon.add_first_party_caveat("time < 900000000") # ms
|
||||||
|
|
||||||
|
self.hs.clock.now = 5000 # seconds
|
||||||
|
self.hs.config.expire_access_token = True
|
||||||
|
|
||||||
|
user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||||
|
user = user_info["user"]
|
||||||
|
self.assertEqual(UserID.from_string(user_id), user)
|
||||||
|
|
|
@ -42,12 +42,12 @@ class RegistrationTestCase(unittest.TestCase):
|
||||||
http_client=None,
|
http_client=None,
|
||||||
expire_access_token=True)
|
expire_access_token=True)
|
||||||
self.auth_handler = Mock(
|
self.auth_handler = Mock(
|
||||||
generate_short_term_login_token=Mock(return_value='secret'))
|
generate_access_token=Mock(return_value='secret'))
|
||||||
self.hs.handlers = RegistrationHandlers(self.hs)
|
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||||
self.handler = self.hs.get_handlers().registration_handler
|
self.handler = self.hs.get_handlers().registration_handler
|
||||||
self.hs.get_handlers().profile_handler = Mock()
|
self.hs.get_handlers().profile_handler = Mock()
|
||||||
self.mock_handler = Mock(spec=[
|
self.mock_handler = Mock(spec=[
|
||||||
"generate_short_term_login_token",
|
"generate_access_token",
|
||||||
])
|
])
|
||||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue