Merge pull request #918 from negzi/bugfix_for_token_expiry

Bug fix: expire invalid access tokens
This commit is contained in:
Erik Johnston 2016-07-14 15:51:52 +01:00 committed by GitHub
commit 209e04fa11
6 changed files with 42 additions and 9 deletions

View file

@ -629,7 +629,10 @@ class Auth(object):
except AuthError: except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens # TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons. # have been re-issued as macaroons.
if self.hs.config.expire_access_token:
raise
ret = yield self._look_up_user_by_access_token(token) ret = yield self._look_up_user_by_access_token(token)
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -637,12 +637,13 @@ class AuthHandler(BaseHandler):
yield self.store.add_refresh_token_to_user(user_id, refresh_token) yield self.store.add_refresh_token_to_user(user_id, refresh_token)
defer.returnValue(refresh_token) defer.returnValue(refresh_token)
def generate_access_token(self, user_id, extra_caveats=None): def generate_access_token(self, user_id, extra_caveats=None,
duration_in_ms=(60 * 60 * 1000)):
extra_caveats = extra_caveats or [] extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
expiry = now + (60 * 60 * 1000) expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("time < %d" % (expiry,))
for caveat in extra_caveats: for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat) macaroon.add_first_party_caveat(caveat)

View file

@ -360,7 +360,7 @@ class RegistrationHandler(BaseHandler):
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_or_create_user(self, localpart, displayname, duration_seconds, def get_or_create_user(self, localpart, displayname, duration_in_ms,
password_hash=None): password_hash=None):
"""Creates a new user if the user does not exist, """Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one. else revokes all previous access tokens and generates a new one.
@ -390,8 +390,8 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
token = self.auth_handler().generate_short_term_login_token( token = self.auth_handler().generate_access_token(
user_id, duration_seconds) user_id, None, duration_in_ms)
if need_register: if need_register:
yield self.store.register( yield self.store.register(

View file

@ -429,7 +429,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
user_id, token = yield handler.get_or_create_user( user_id, token = yield handler.get_or_create_user(
localpart=localpart, localpart=localpart,
displayname=displayname, displayname=displayname,
duration_seconds=duration_seconds, duration_in_ms=(duration_seconds * 1000),
password_hash=password_hash password_hash=password_hash
) )

View file

@ -281,7 +281,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,)) macaroon.add_first_party_caveat("user_id = %s" % (user,))
macaroon.add_first_party_caveat("time < 1") # ms macaroon.add_first_party_caveat("time < -2000") # ms
self.hs.clock.now = 5000 # seconds self.hs.clock.now = 5000 # seconds
self.hs.config.expire_access_token = True self.hs.config.expire_access_token = True
@ -293,3 +293,32 @@ class AuthTestCase(unittest.TestCase):
yield self.auth.get_user_from_macaroon(macaroon.serialize()) yield self.auth.get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code) self.assertEqual(401, cm.exception.code)
self.assertIn("Invalid macaroon", cm.exception.msg) self.assertIn("Invalid macaroon", cm.exception.msg)
@defer.inlineCallbacks
def test_get_user_from_macaroon_with_valid_duration(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
macaroon.add_first_party_caveat("time < 900000000") # ms
self.hs.clock.now = 5000 # seconds
self.hs.config.expire_access_token = True
user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize())
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)

View file

@ -42,12 +42,12 @@ class RegistrationTestCase(unittest.TestCase):
http_client=None, http_client=None,
expire_access_token=True) expire_access_token=True)
self.auth_handler = Mock( self.auth_handler = Mock(
generate_short_term_login_token=Mock(return_value='secret')) generate_access_token=Mock(return_value='secret'))
self.hs.handlers = RegistrationHandlers(self.hs) self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler self.handler = self.hs.get_handlers().registration_handler
self.hs.get_handlers().profile_handler = Mock() self.hs.get_handlers().profile_handler = Mock()
self.mock_handler = Mock(spec=[ self.mock_handler = Mock(spec=[
"generate_short_term_login_token", "generate_access_token",
]) ])
self.hs.get_auth_handler = Mock(return_value=self.auth_handler) self.hs.get_auth_handler = Mock(return_value=self.auth_handler)