0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 14:23:50 +01:00

Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.19.0

This commit is contained in:
Erik Johnston 2017-02-02 16:50:28 +00:00
commit f8c407a13b
16 changed files with 135 additions and 107 deletions

View file

@ -65,6 +65,7 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(self, flows, clientdict, clientip):
@ -529,37 +530,11 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def issue_access_token(self, user_id, device_id=None): def issue_access_token(self, user_id, device_id=None):
access_token = self.generate_access_token(user_id) access_token = self.macaroon_gen.generate_access_token(user_id)
yield self.store.add_access_token_to_user(user_id, access_token, yield self.store.add_access_token_to_user(user_id, access_token,
device_id) device_id)
defer.returnValue(access_token) defer.returnValue(access_token)
def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
# Include a nonce, to make sure that each login gets a different
# access token.
macaroon.add_first_party_caveat("nonce = %s" % (
stringutils.random_string_with_symbols(16),
))
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def generate_delete_pusher_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token): def validate_short_term_login_token_and_get_user_id(self, login_token):
auth_api = self.hs.get_auth() auth_api = self.hs.get_auth()
try: try:
@ -570,15 +545,6 @@ class AuthHandler(BaseHandler):
except Exception: except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
@defer.inlineCallbacks @defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None): def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword) password_hash = self.hash(newpassword)
@ -673,6 +639,48 @@ class AuthHandler(BaseHandler):
return False return False
class MacaroonGeneartor(object):
def __init__(self, hs):
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
self.macaroon_secret_key = hs.config.macaroon_secret_key
def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
# Include a nonce, to make sure that each login gets a different
# access token.
macaroon.add_first_party_caveat("nonce = %s" % (
stringutils.random_string_with_symbols(16),
))
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def generate_delete_pusher_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
location=self.server_name,
identifier="key",
key=self.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
class _AccountHandler(object): class _AccountHandler(object):
"""A proxy object that gets passed to password auth providers so they """A proxy object that gets passed to password auth providers so they
can register new users etc if necessary. can register new users etc if necessary.

View file

@ -17,7 +17,7 @@ from synapse.api import errors
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id, RoomStreamToken
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
@ -198,20 +198,22 @@ class DeviceHandler(BaseHandler):
"""Notify that a user's device(s) has changed. Pokes the notifier, and """Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local. remote servers if the user is local.
""" """
rooms = yield self.store.get_rooms_for_user(user_id) users_who_share_room = yield self.store.get_users_who_share_room_with_user(
room_ids = [r.room_id for r in rooms] user_id
)
hosts = set() hosts = set()
if self.hs.is_mine_id(user_id): if self.hs.is_mine_id(user_id):
for room_id in room_ids: hosts.update(get_domain_from_id(u) for u in users_who_share_room)
users = yield self.store.get_users_in_room(room_id)
hosts.update(get_domain_from_id(u) for u in users)
hosts.discard(self.server_name) hosts.discard(self.server_name)
position = yield self.store.add_device_change_to_streams( position = yield self.store.add_device_change_to_streams(
user_id, device_ids, list(hosts) user_id, device_ids, list(hosts)
) )
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
yield self.notifier.on_new_event( yield self.notifier.on_new_event(
"device_list_key", position, rooms=room_ids, "device_list_key", position, rooms=room_ids,
) )
@ -243,15 +245,15 @@ class DeviceHandler(BaseHandler):
possibly_changed = set(changed) possibly_changed = set(changed)
for room_id in rooms_changed: for room_id in rooms_changed:
# Fetch (an approximation) of the current state at the time. # Fetch the current state at the time.
event_rows, token = yield self.store.get_recent_event_ids_for_room( stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key)
room_id, end_token=from_token.room_key, limit=1,
)
if event_rows: try:
last_event_id = event_rows[-1]["event_id"] event_ids = yield self.store.get_forward_extremeties_for_room(
prev_state_ids = yield self.store.get_state_ids_for_event(last_event_id) room_id, stream_ordering=stream_ordering
else: )
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
except:
prev_state_ids = {} prev_state_ids = {}
current_state_ids = yield self.state.get_current_state_ids(room_id) current_state_ids = yield self.state.get_current_state_ids(room_id)
@ -266,13 +268,13 @@ class DeviceHandler(BaseHandler):
if not prev_event_id or prev_event_id != event_id: if not prev_event_id or prev_event_id != event_id:
possibly_changed.add(state_key) possibly_changed.add(state_key)
user_ids_changed = set() users_who_share_room = yield self.store.get_users_who_share_room_with_user(
for other_user_id in possibly_changed: user_id
other_rooms = yield self.store.get_rooms_for_user(other_user_id) )
if room_ids.intersection(e.room_id for e in other_rooms):
user_ids_changed.add(other_user_id)
defer.returnValue(user_ids_changed) # Take the intersection of the users whose devices may have changed
# and those that actually still share a room with the user
defer.returnValue(users_who_share_room & possibly_changed)
@defer.inlineCallbacks @defer.inlineCallbacks
def _incoming_device_list_update(self, origin, edu_content): def _incoming_device_list_update(self, origin, edu_content):

View file

@ -1011,7 +1011,7 @@ class PresenceEventSource(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_new_events(self, user, from_key, room_ids=None, include_offline=True, def get_new_events(self, user, from_key, room_ids=None, include_offline=True,
**kwargs): explicit_room_id=None, **kwargs):
# The process for getting presence events are: # The process for getting presence events are:
# 1. Get the rooms the user is in. # 1. Get the rooms the user is in.
# 2. Get the list of user in the rooms. # 2. Get the list of user in the rooms.
@ -1028,22 +1028,24 @@ class PresenceEventSource(object):
user_id = user.to_string() user_id = user.to_string()
if from_key is not None: if from_key is not None:
from_key = int(from_key) from_key = int(from_key)
room_ids = room_ids or []
presence = self.get_presence_handler() presence = self.get_presence_handler()
stream_change_cache = self.store.presence_stream_cache stream_change_cache = self.store.presence_stream_cache
if not room_ids:
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = set(e.room_id for e in rooms)
else:
room_ids = set(room_ids)
max_token = self.store.get_current_presence_token() max_token = self.store.get_current_presence_token()
plist = yield self.store.get_presence_list_accepted(user.localpart) plist = yield self.store.get_presence_list_accepted(user.localpart)
friends = set(row["observed_user_id"] for row in plist) users_interested_in = set(row["observed_user_id"] for row in plist)
friends.add(user_id) # So that we receive our own presence users_interested_in.add(user_id) # So that we receive our own presence
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
users_interested_in.update(users_who_share_room)
if explicit_room_id:
user_ids = yield self.store.get_users_in_room(explicit_room_id)
users_interested_in.update(user_ids)
user_ids_changed = set() user_ids_changed = set()
changed = None changed = None
@ -1055,35 +1057,19 @@ class PresenceEventSource(object):
# work out if we share a room or they're in our presence list # work out if we share a room or they're in our presence list
get_updates_counter.inc("stream") get_updates_counter.inc("stream")
for other_user_id in changed: for other_user_id in changed:
if other_user_id in friends: if other_user_id in users_interested_in:
user_ids_changed.add(other_user_id) user_ids_changed.add(other_user_id)
continue
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
if room_ids.intersection(e.room_id for e in other_rooms):
user_ids_changed.add(other_user_id)
continue
else: else:
# Too many possible updates. Find all users we can see and check # Too many possible updates. Find all users we can see and check
# if any of them have changed. # if any of them have changed.
get_updates_counter.inc("full") get_updates_counter.inc("full")
user_ids_to_check = set()
for room_id in room_ids:
users = yield self.store.get_users_in_room(room_id)
user_ids_to_check.update(users)
user_ids_to_check.update(friends)
# Always include yourself. Only really matters for when the user is
# not in any rooms, but still.
user_ids_to_check.add(user_id)
if from_key: if from_key:
user_ids_changed = stream_change_cache.get_entities_changed( user_ids_changed = stream_change_cache.get_entities_changed(
user_ids_to_check, from_key, users_interested_in, from_key,
) )
else: else:
user_ids_changed = user_ids_to_check user_ids_changed = users_interested_in
updates = yield presence.current_state_for_users(user_ids_changed) updates = yield presence.current_state_for_users(user_ids_changed)

View file

@ -40,6 +40,8 @@ class RegistrationHandler(BaseHandler):
self._next_generated_user_id = None self._next_generated_user_id = None
self.macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks @defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None, def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None): assigned_user_id=None):
@ -143,7 +145,7 @@ class RegistrationHandler(BaseHandler):
token = None token = None
if generate_token: if generate_token:
token = self.auth_handler().generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -167,7 +169,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_not_appservice_exclusive(user_id) yield self.check_user_id_not_appservice_exclusive(user_id)
if generate_token: if generate_token:
token = self.auth_handler().generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
try: try:
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
@ -254,7 +256,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_not_appservice_exclusive(user_id) yield self.check_user_id_not_appservice_exclusive(user_id)
token = self.auth_handler().generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
try: try:
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
@ -399,7 +401,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
token = self.auth_handler().generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
if need_register: if need_register:
yield self.store.register( yield self.store.register(

View file

@ -437,6 +437,7 @@ class RoomEventSource(object):
limit, limit,
room_ids, room_ids,
is_guest, is_guest,
explicit_room_id=None,
): ):
# We just ignore the key for now. # We just ignore the key for now.

View file

@ -378,6 +378,7 @@ class Notifier(object):
limit=limit, limit=limit,
is_guest=is_peeking, is_guest=is_peeking,
room_ids=room_ids, room_ids=room_ids,
explicit_room_id=explicit_room_id,
) )
if name == "room": if name == "room":

View file

@ -81,7 +81,7 @@ class Mailer(object):
def __init__(self, hs, app_name): def __init__(self, hs, app_name):
self.hs = hs self.hs = hs
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.auth_handler = self.hs.get_auth_handler() self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler() self.state_handler = self.hs.get_state_handler()
loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir) loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
self.app_name = app_name self.app_name = app_name
@ -466,7 +466,7 @@ class Mailer(object):
def make_unsubscribe_link(self, user_id, app_id, email_address): def make_unsubscribe_link(self, user_id, app_id, email_address):
params = { params = {
"access_token": self.auth_handler.generate_delete_pusher_token(user_id), "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
"app_id": app_id, "app_id": app_id,
"pushkey": email_address, "pushkey": email_address,
} }

View file

@ -73,6 +73,9 @@ class SlavedEventStore(BaseSlavedStore):
# to reach inside the __dict__ to extract them. # to reach inside the __dict__ to extract them.
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
get_users_who_share_room_with_user = (
RoomMemberStore.__dict__["get_users_who_share_room_with_user"]
)
get_latest_event_ids_in_room = EventFederationStore.__dict__[ get_latest_event_ids_in_room = EventFederationStore.__dict__[
"get_latest_event_ids_in_room" "get_latest_event_ids_in_room"
] ]

View file

@ -330,6 +330,7 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_required_attributes = hs.config.cas_required_attributes self.cas_required_attributes = hs.config.cas_required_attributes
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -368,7 +369,9 @@ class CasTicketServlet(ClientV1RestServlet):
yield self.handlers.registration_handler.register(localpart=user) yield self.handlers.registration_handler.register(localpart=user)
) )
login_token = auth_handler.generate_short_term_login_token(registered_user_id) login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
)
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token) login_token)
request.redirect(redirect_url) request.redirect(redirect_url)

View file

@ -193,7 +193,7 @@ class KeyChangesServlet(RestServlet):
) )
defer.returnValue((200, { defer.returnValue((200, {
"changed": changed "changed": list(changed),
})) }))

View file

@ -96,6 +96,7 @@ class RegisterRestServlet(RestServlet):
self.registration_handler = hs.get_handlers().registration_handler self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -436,7 +437,7 @@ class RegisterRestServlet(RestServlet):
user_id, device_id, initial_display_name user_id, device_id, initial_display_name
) )
access_token = self.auth_handler.generate_access_token( access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"] user_id, ["guest = true"]
) )
defer.returnValue((200, { defer.returnValue((200, {

View file

@ -37,7 +37,7 @@ from synapse.federation.transport.client import TransportLayerClient
from synapse.federation.transaction_queue import TransactionQueue from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_keys import E2eKeysHandler
@ -131,6 +131,7 @@ class HomeServer(object):
'federation_transport_client', 'federation_transport_client',
'federation_sender', 'federation_sender',
'receipts_handler', 'receipts_handler',
'macaroon_generator',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -213,6 +214,9 @@ class HomeServer(object):
def build_auth_handler(self): def build_auth_handler(self):
return AuthHandler(self) return AuthHandler(self)
def build_macaroon_generator(self):
return MacaroonGeneartor(self)
def build_device_handler(self): def build_device_handler(self):
return DeviceHandler(self) return DeviceHandler(self)

View file

@ -280,6 +280,23 @@ class RoomMemberStore(SQLBaseStore):
user_id, membership_list=[Membership.JOIN], user_id, membership_list=[Membership.JOIN],
) )
@cachedInlineCallbacks(max_entries=50000, cache_context=True, iterable=True)
def get_users_who_share_room_with_user(self, user_id, cache_context):
"""Returns the set of users who share a room with `user_id`
"""
rooms = yield self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate,
)
user_who_share_room = set()
for room in rooms:
user_ids = yield self.get_users_in_room(
room.room_id, on_invalidate=cache_context.invalidate,
)
user_who_share_room.update(user_ids)
defer.returnValue(user_who_share_room)
def forget(self, user_id, room_id): def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id.""" """Indicate that user_id wishes to discard history for room_id."""
def f(txn): def f(txn):

View file

@ -478,6 +478,11 @@ class CacheListDescriptor(object):
class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
# We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
# which namedtuple does for us (i.e. two _CacheContext are the same if
# their caches and keys match). This is important in particular to
# dedupe when we add callbacks to lru cache nodes, otherwise the number
# of callbacks would grow.
def invalidate(self): def invalidate(self):
self.cache.invalidate(self.key) self.cache.invalidate(self.key)

View file

@ -34,11 +34,10 @@ class AuthTestCase(unittest.TestCase):
self.hs = yield setup_test_homeserver(handlers=None) self.hs = yield setup_test_homeserver(handlers=None)
self.hs.handlers = AuthHandlers(self.hs) self.hs.handlers = AuthHandlers(self.hs)
self.auth_handler = self.hs.handlers.auth_handler self.auth_handler = self.hs.handlers.auth_handler
self.macaroon_generator = self.hs.get_macaroon_generator()
def test_token_is_a_macaroon(self): def test_token_is_a_macaroon(self):
self.hs.config.macaroon_secret_key = "this key is a huge secret" token = self.macaroon_generator.generate_access_token("some_user")
token = self.auth_handler.generate_access_token("some_user")
# Check that we can parse the thing with pymacaroons # Check that we can parse the thing with pymacaroons
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
# The most basic of sanity checks # The most basic of sanity checks
@ -46,10 +45,9 @@ class AuthTestCase(unittest.TestCase):
self.fail("some_user was not in %s" % macaroon.inspect()) self.fail("some_user was not in %s" % macaroon.inspect())
def test_macaroon_caveats(self): def test_macaroon_caveats(self):
self.hs.config.macaroon_secret_key = "this key is a massive secret"
self.hs.clock.now = 5000 self.hs.clock.now = 5000
token = self.auth_handler.generate_access_token("a_user") token = self.macaroon_generator.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
def verify_gen(caveat): def verify_gen(caveat):
@ -74,7 +72,7 @@ class AuthTestCase(unittest.TestCase):
def test_short_term_login_token_gives_user_id(self): def test_short_term_login_token_gives_user_id(self):
self.hs.clock.now = 1000 self.hs.clock.now = 1000
token = self.auth_handler.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
"a_user", 5000 "a_user", 5000
) )
@ -93,7 +91,7 @@ class AuthTestCase(unittest.TestCase):
) )
def test_short_term_login_token_cannot_replace_user_id(self): def test_short_term_login_token_cannot_replace_user_id(self):
token = self.auth_handler.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
"a_user", 5000 "a_user", 5000
) )
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)

View file

@ -41,15 +41,12 @@ class RegistrationTestCase(unittest.TestCase):
handlers=None, handlers=None,
http_client=None, http_client=None,
expire_access_token=True) expire_access_token=True)
self.auth_handler = Mock( self.macaroon_generator = Mock(
generate_access_token=Mock(return_value='secret')) generate_access_token=Mock(return_value='secret'))
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
self.hs.handlers = RegistrationHandlers(self.hs) self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler self.handler = self.hs.get_handlers().registration_handler
self.hs.get_handlers().profile_handler = Mock() self.hs.get_handlers().profile_handler = Mock()
self.mock_handler = Mock(spec=[
"generate_access_token",
])
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self): def test_user_is_created_and_logged_in_if_doesnt_exist(self):