Introduce a Requester object

This tracks data about the entity which made the request. This is
instead of passing around a tuple, which requires call-site
modifications every time a new piece of optional context is passed
around.

I tried to introduce a User object. I gave up.
This commit is contained in:
Daniel Wagner-Hall 2016-01-11 15:29:57 +00:00
parent fcbe63eaad
commit 2110e35fd6
24 changed files with 178 additions and 144 deletions

View file

@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import RoomID, UserID, EventID from synapse.types import Requester, RoomID, UserID, EventID
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
@ -534,7 +534,9 @@ class Auth(object):
request.authenticated_entity = user_id request.authenticated_entity = user_id
defer.returnValue((UserID.from_string(user_id), "", False)) defer.returnValue(
Requester(UserID.from_string(user_id), "", False)
)
return return
except KeyError: except KeyError:
pass # normal users won't have the user_id query parameter set. pass # normal users won't have the user_id query parameter set.
@ -564,7 +566,7 @@ class Auth(object):
request.authenticated_entity = user.to_string() request.authenticated_entity = user.to_string()
defer.returnValue((user, token_id, is_guest,)) defer.returnValue(Requester(user, token_id, is_guest))
except KeyError: except KeyError:
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",

View file

@ -31,8 +31,9 @@ class WhoisRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(auth_user) auth_user = requester.user
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin and target_user != auth_user: if not is_admin and target_user != auth_user:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")

View file

@ -69,9 +69,9 @@ class ClientDirectoryServer(ClientV1RestServlet):
try: try:
# try to auth as a user # try to auth as a user
user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
try: try:
user_id = user.to_string() user_id = requester.user.to_string()
yield dir_handler.create_association( yield dir_handler.create_association(
user_id, room_alias, room_id, servers user_id, room_alias, room_id, servers
) )
@ -116,8 +116,8 @@ class ClientDirectoryServer(ClientV1RestServlet):
# fallback to default user behaviour if they aren't an AS # fallback to default user behaviour if they aren't an AS
pass pass
user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = requester.user
is_admin = yield self.auth.is_server_admin(user) is_admin = yield self.auth.is_server_admin(user)
if not is_admin: if not is_admin:
raise AuthError(403, "You need to be a server admin") raise AuthError(403, "You need to be a server admin")

View file

@ -34,10 +34,11 @@ class EventStreamRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
auth_user, _, is_guest = yield self.auth.get_user_by_req( requester = yield self.auth.get_user_by_req(
request, request,
allow_guest=True allow_guest=True,
) )
is_guest = requester.is_guest
room_id = None room_id = None
if is_guest: if is_guest:
if "room_id" not in request.args: if "room_id" not in request.args:
@ -56,9 +57,13 @@ class EventStreamRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
chunk = yield handler.get_stream( chunk = yield handler.get_stream(
auth_user.to_string(), pagin_config, timeout=timeout, requester.user.to_string(),
as_client_event=as_client_event, affect_presence=(not is_guest), pagin_config,
room_id=room_id, is_guest=is_guest timeout=timeout,
as_client_event=as_client_event,
affect_presence=(not is_guest),
room_id=room_id,
is_guest=is_guest,
) )
except: except:
logger.exception("Event stream failed") logger.exception("Event stream failed")
@ -80,9 +85,9 @@ class EventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, event_id): def on_GET(self, request, event_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler handler = self.handlers.event_handler
event = yield handler.get_event(auth_user, event_id) event = yield handler.get_event(requester.user, event_id)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
if event: if event:

View file

@ -25,13 +25,13 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler handler = self.handlers.message_handler
include_archived = request.args.get("archived", None) == ["true"] include_archived = request.args.get("archived", None) == ["true"]
content = yield handler.snapshot_all_rooms( content = yield handler.snapshot_all_rooms(
user_id=user.to_string(), user_id=requester.user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,
as_client_event=as_client_event, as_client_event=as_client_event,
include_archived=include_archived, include_archived=include_archived,

View file

@ -32,17 +32,17 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state( state = yield self.handlers.presence_handler.get_state(
target_user=user, auth_user=auth_user) target_user=user, auth_user=requester.user)
defer.returnValue((200, state)) defer.returnValue((200, state))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = {} state = {}
@ -64,7 +64,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Unable to parse state") raise SynapseError(400, "Unable to parse state")
yield self.handlers.presence_handler.set_state( yield self.handlers.presence_handler.set_state(
target_user=user, auth_user=auth_user, state=state) target_user=user, auth_user=requester.user, state=state)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -77,13 +77,13 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server") raise SynapseError(400, "User not hosted on this Home Server")
if auth_user != user: if requester.user != user:
raise SynapseError(400, "Cannot get another user's presence list") raise SynapseError(400, "Cannot get another user's presence list")
presence = yield self.handlers.presence_handler.get_presence_list( presence = yield self.handlers.presence_handler.get_presence_list(
@ -97,13 +97,13 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server") raise SynapseError(400, "User not hosted on this Home Server")
if auth_user != user: if requester.user != user:
raise SynapseError( raise SynapseError(
400, "Cannot modify another user's presence list") 400, "Cannot modify another user's presence list")

View file

@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
try: try:
@ -47,7 +47,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_displayname( yield self.handlers.profile_handler.set_displayname(
user, auth_user, new_name) user, requester.user, new_name)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
try: try:
@ -80,7 +80,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_avatar_url( yield self.handlers.profile_handler.set_avatar_url(
user, auth_user, new_name) user, requester.user, new_name)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View file

@ -43,7 +43,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
if '/' in spec['rule_id'] or '\\' in spec['rule_id']: if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
raise SynapseError(400, "rule_id may not contain slashes") raise SynapseError(400, "rule_id may not contain slashes")
@ -51,7 +51,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
content = _parse_json(request) content = _parse_json(request)
if 'attr' in spec: if 'attr' in spec:
self.set_rule_attr(user.to_string(), spec, content) self.set_rule_attr(requester.user, spec, content)
defer.returnValue((200, {})) defer.returnValue((200, {}))
try: try:
@ -73,7 +73,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
try: try:
yield self.hs.get_datastore().add_push_rule( yield self.hs.get_datastore().add_push_rule(
user_name=user.to_string(), user_name=requester.user.to_string(),
rule_id=_namespaced_rule_id_from_spec(spec), rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class, priority_class=priority_class,
conditions=conditions, conditions=conditions,
@ -92,13 +92,13 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_DELETE(self, request): def on_DELETE(self, request):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path(request.postpath)
user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try: try:
yield self.hs.get_datastore().delete_push_rule( yield self.hs.get_datastore().delete_push_rule(
user.to_string(), namespaced_rule_id requester.user.to_string(), namespaced_rule_id
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
except StoreError as e: except StoreError as e:
@ -109,7 +109,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = requester.user
# we build up the full structure and then decide which bits of it # we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is # to send which means doing unnecessary work sometimes but is

View file

@ -30,7 +30,8 @@ class PusherRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
user, token_id, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = requester.user
content = _parse_json(request) content = _parse_json(request)
@ -71,7 +72,7 @@ class PusherRestServlet(ClientV1RestServlet):
try: try:
yield pusher_pool.add_pusher( yield pusher_pool.add_pusher(
user_name=user.to_string(), user_name=user.to_string(),
access_token=token_id, access_token=requester.access_token_id,
profile_tag=content['profile_tag'], profile_tag=content['profile_tag'],
kind=content['kind'], kind=content['kind'],
app_id=content['app_id'], app_id=content['app_id'],

View file

@ -61,10 +61,14 @@ class RoomCreateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request) room_config = self.get_room_config(request)
info = yield self.make_room(room_config, auth_user, None) info = yield self.make_room(
room_config,
requester.user,
None,
)
room_config.update(info) room_config.update(info)
defer.returnValue((200, info)) defer.returnValue((200, info))
@ -124,15 +128,15 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key): def on_GET(self, request, room_id, event_type, state_key):
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data( data = yield msg_handler.get_room_data(
user_id=user.to_string(), user_id=requester.user.to_string(),
room_id=room_id, room_id=room_id,
event_type=event_type, event_type=event_type,
state_key=state_key, state_key=state_key,
is_guest=is_guest, is_guest=requester.is_guest,
) )
if not data: if not data:
@ -143,7 +147,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
user, token_id, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
@ -151,7 +155,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
"type": event_type, "type": event_type,
"content": content, "content": content,
"room_id": room_id, "room_id": room_id,
"sender": user.to_string(), "sender": requester.user.to_string(),
} }
if state_key is not None: if state_key is not None:
@ -159,7 +163,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event( yield msg_handler.create_and_send_event(
event_dict, token_id=token_id, txn_id=txn_id, event_dict, token_id=requester.access_token_id, txn_id=txn_id,
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -175,7 +179,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_type, txn_id=None): def on_POST(self, request, room_id, event_type, txn_id=None):
user, token_id, _ = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = _parse_json(request) content = _parse_json(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
@ -184,9 +188,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
"type": event_type, "type": event_type,
"content": content, "content": content,
"room_id": room_id, "room_id": room_id,
"sender": user.to_string(), "sender": requester.user.to_string(),
}, },
token_id=token_id, token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
) )
@ -220,9 +224,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_identifier, txn_id=None): def on_POST(self, request, room_identifier, txn_id=None):
user, token_id, is_guest = yield self.auth.get_user_by_req( requester = yield self.auth.get_user_by_req(
request, request,
allow_guest=True allow_guest=True,
) )
# the identifier could be a room alias or a room id. Try one then the # the identifier could be a room alias or a room id. Try one then the
@ -241,24 +245,27 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
if is_room_alias: if is_room_alias:
handler = self.handlers.room_member_handler handler = self.handlers.room_member_handler
ret_dict = yield handler.join_room_alias(user, identifier) ret_dict = yield handler.join_room_alias(
requester.user,
identifier,
)
defer.returnValue((200, ret_dict)) defer.returnValue((200, ret_dict))
else: # room id else: # room id
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
content = {"membership": Membership.JOIN} content = {"membership": Membership.JOIN}
if is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
yield msg_handler.create_and_send_event( yield msg_handler.create_and_send_event(
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
"content": content, "content": content,
"room_id": identifier.to_string(), "room_id": identifier.to_string(),
"sender": user.to_string(), "sender": requester.user.to_string(),
"state_key": user.to_string(), "state_key": requester.user.to_string(),
}, },
token_id=token_id, token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
is_guest=is_guest, is_guest=requester.is_guest,
) )
defer.returnValue((200, {"room_id": identifier.to_string()})) defer.returnValue((200, {"room_id": identifier.to_string()}))
@ -296,11 +303,11 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler handler = self.handlers.message_handler
events = yield handler.get_state_events( events = yield handler.get_state_events(
room_id=room_id, room_id=room_id,
user_id=user.to_string(), user_id=requester.user.to_string(),
) )
chunk = [] chunk = []
@ -315,7 +322,8 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
try: try:
presence_handler = self.handlers.presence_handler presence_handler = self.handlers.presence_handler
presence_state = yield presence_handler.get_state( presence_state = yield presence_handler.get_state(
target_user=target_user, auth_user=user target_user=target_user,
auth_user=requester.user,
) )
event["content"].update(presence_state) event["content"].update(presence_state)
except: except:
@ -332,7 +340,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request( pagination_config = PaginationConfig.from_request(
request, default_limit=10, request, default_limit=10,
) )
@ -340,8 +348,8 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
handler = self.handlers.message_handler handler = self.handlers.message_handler
msgs = yield handler.get_messages( msgs = yield handler.get_messages(
room_id=room_id, room_id=room_id,
user_id=user.to_string(), user_id=requester.user.to_string(),
is_guest=is_guest, is_guest=requester.is_guest,
pagin_config=pagination_config, pagin_config=pagination_config,
as_client_event=as_client_event as_client_event=as_client_event
) )
@ -355,13 +363,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
handler = self.handlers.message_handler handler = self.handlers.message_handler
# Get all the current state for this room # Get all the current state for this room
events = yield handler.get_state_events( events = yield handler.get_state_events(
room_id=room_id, room_id=room_id,
user_id=user.to_string(), user_id=requester.user.to_string(),
is_guest=is_guest, is_guest=requester.is_guest,
) )
defer.returnValue((200, events)) defer.returnValue((200, events))
@ -372,13 +380,13 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync( content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id, room_id=room_id,
user_id=user.to_string(), user_id=requester.user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,
is_guest=is_guest, is_guest=requester.is_guest,
) )
defer.returnValue((200, content)) defer.returnValue((200, content))
@ -394,12 +402,16 @@ class RoomEventContext(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
limit = int(request.args.get("limit", [10])[0]) limit = int(request.args.get("limit", [10])[0])
results = yield self.handlers.room_context_handler.get_event_context( results = yield self.handlers.room_context_handler.get_event_context(
user, room_id, event_id, limit, is_guest requester.user,
room_id,
event_id,
limit,
requester.is_guest,
) )
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -429,14 +441,18 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action, txn_id=None): def on_POST(self, request, room_id, membership_action, txn_id=None):
user, token_id, is_guest = yield self.auth.get_user_by_req( requester = yield self.auth.get_user_by_req(
request, request,
allow_guest=True allow_guest=True,
) )
user = requester.user
effective_membership_action = membership_action effective_membership_action = membership_action
if is_guest and membership_action not in {Membership.JOIN, Membership.LEAVE}: if requester.is_guest and membership_action not in {
Membership.JOIN,
Membership.LEAVE
}:
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
content = _parse_json(request) content = _parse_json(request)
@ -451,7 +467,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content["medium"], content["medium"],
content["address"], content["address"],
content["id_server"], content["id_server"],
token_id, requester.access_token_id,
txn_id txn_id
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -473,7 +489,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
content = {"membership": unicode(effective_membership_action)} content = {"membership": unicode(effective_membership_action)}
if is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
yield msg_handler.create_and_send_event( yield msg_handler.create_and_send_event(
@ -484,9 +500,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
"sender": user.to_string(), "sender": user.to_string(),
"state_key": state_key, "state_key": state_key,
}, },
token_id=token_id, token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
is_guest=is_guest, is_guest=requester.is_guest,
) )
if membership_action == "forget": if membership_action == "forget":
@ -524,7 +540,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id, txn_id=None): def on_POST(self, request, room_id, event_id, txn_id=None):
user, token_id, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
@ -533,10 +549,10 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
"type": EventTypes.Redaction, "type": EventTypes.Redaction,
"content": content, "content": content,
"room_id": room_id, "room_id": room_id,
"sender": user.to_string(), "sender": requester.user.to_string(),
"redacts": event_id, "redacts": event_id,
}, },
token_id=token_id, token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
) )
@ -564,7 +580,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id): def on_PUT(self, request, room_id, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
room_id = urllib.unquote(room_id) room_id = urllib.unquote(room_id)
target_user = UserID.from_string(urllib.unquote(user_id)) target_user = UserID.from_string(urllib.unquote(user_id))
@ -576,14 +592,14 @@ class RoomTypingRestServlet(ClientV1RestServlet):
if content["typing"]: if content["typing"]:
yield typing_handler.started_typing( yield typing_handler.started_typing(
target_user=target_user, target_user=target_user,
auth_user=auth_user, auth_user=requester.user,
room_id=room_id, room_id=room_id,
timeout=content.get("timeout", 30000), timeout=content.get("timeout", 30000),
) )
else: else:
yield typing_handler.stopped_typing( yield typing_handler.stopped_typing(
target_user=target_user, target_user=target_user,
auth_user=auth_user, auth_user=requester.user,
room_id=room_id, room_id=room_id,
) )
@ -597,12 +613,16 @@ class SearchRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
batch = request.args.get("next_batch", [None])[0] batch = request.args.get("next_batch", [None])[0]
results = yield self.handlers.search_handler.search(auth_user, content, batch) results = yield self.handlers.search_handler.search(
requester.user,
content,
batch,
)
defer.returnValue((200, results)) defer.returnValue((200, results))

View file

@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
turnUris = self.hs.config.turn_uris turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret turnSecret = self.hs.config.turn_shared_secret
@ -37,7 +37,7 @@ class VoipRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000 expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
username = "%d:%s" % (expiry, auth_user.to_string()) username = "%d:%s" % (expiry, requester.user.to_string())
mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1) mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
# We need to use standard padded base64 encoding here # We need to use standard padded base64 encoding here

View file

@ -55,10 +55,11 @@ class PasswordRestServlet(RestServlet):
if LoginType.PASSWORD in result: if LoginType.PASSWORD in result:
# if using password, they should also be logged in # if using password, they should also be logged in
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
if auth_user.to_string() != result[LoginType.PASSWORD]: requester_user_id = requester.user.to_string()
if requester_user_id.to_string() != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN) raise LoginError(400, "", Codes.UNKNOWN)
user_id = auth_user.to_string() user_id = requester_user_id
elif LoginType.EMAIL_IDENTITY in result: elif LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY] threepid = result[LoginType.EMAIL_IDENTITY]
if 'medium' not in threepid or 'address' not in threepid: if 'medium' not in threepid or 'address' not in threepid:
@ -102,10 +103,10 @@ class ThreepidRestServlet(RestServlet):
def on_GET(self, request): def on_GET(self, request):
yield run_on_reactor() yield run_on_reactor()
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
threepids = yield self.hs.get_datastore().user_get_threepids( threepids = yield self.hs.get_datastore().user_get_threepids(
auth_user.to_string() requester.user.to_string()
) )
defer.returnValue((200, {'threepids': threepids})) defer.returnValue((200, {'threepids': threepids}))
@ -120,7 +121,8 @@ class ThreepidRestServlet(RestServlet):
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
threePidCreds = body['threePidCreds'] threePidCreds = body['threePidCreds']
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
@ -135,7 +137,7 @@ class ThreepidRestServlet(RestServlet):
raise SynapseError(500, "Invalid response from ID Server") raise SynapseError(500, "Invalid response from ID Server")
yield self.auth_handler.add_threepid( yield self.auth_handler.add_threepid(
auth_user.to_string(), user_id,
threepid['medium'], threepid['medium'],
threepid['address'], threepid['address'],
threepid['validated_at'], threepid['validated_at'],
@ -144,10 +146,10 @@ class ThreepidRestServlet(RestServlet):
if 'bind' in body and body['bind']: if 'bind' in body and body['bind']:
logger.debug( logger.debug(
"Binding emails %s to %s", "Binding emails %s to %s",
threepid, auth_user.to_string() threepid, user_id
) )
yield self.identity_handler.bind_threepid( yield self.identity_handler.bind_threepid(
threePidCreds, auth_user.to_string() threePidCreds, user_id
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))

View file

@ -43,8 +43,8 @@ class AccountDataServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id, account_data_type): def on_PUT(self, request, user_id, account_data_type):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
try: try:
@ -82,8 +82,8 @@ class RoomAccountDataServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, account_data_type): def on_PUT(self, request, user_id, room_id, account_data_type):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
try: try:

View file

@ -40,9 +40,9 @@ class GetFilterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, filter_id): def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
if target_user != auth_user: if target_user != requester.user:
raise AuthError(403, "Cannot get filters for other users") raise AuthError(403, "Cannot get filters for other users")
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
@ -76,9 +76,9 @@ class CreateFilterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
if target_user != auth_user: if target_user != requester.user:
raise AuthError(403, "Cannot create filters for other users") raise AuthError(403, "Cannot create filters for other users")
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):

View file

@ -64,8 +64,8 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, device_id): def on_POST(self, request, device_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string() user_id = requester.user.to_string()
# TODO: Check that the device_id matches that in the authentication # TODO: Check that the device_id matches that in the authentication
# or derive the device_id from the authentication instead. # or derive the device_id from the authentication instead.
try: try:
@ -78,8 +78,8 @@ class KeyUploadServlet(RestServlet):
device_keys = body.get("device_keys", None) device_keys = body.get("device_keys", None)
if device_keys: if device_keys:
logger.info( logger.info(
"Updating device_keys for device %r for user %r at %d", "Updating device_keys for device %r for user %s at %d",
device_id, auth_user, time_now device_id, user_id, time_now
) )
# TODO: Sign the JSON with the server key # TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys(
@ -109,8 +109,8 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, device_id): def on_GET(self, request, device_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string() user_id = requester.user.to_string()
result = yield self.store.count_e2e_one_time_keys(user_id, device_id) result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result})) defer.returnValue((200, {"one_time_key_counts": result}))
@ -182,8 +182,8 @@ class KeyQueryServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id): def on_GET(self, request, user_id, device_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
auth_user_id = auth_user.to_string() auth_user_id = requester.user.to_string()
user_id = user_id if user_id else auth_user_id user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else [] device_ids = [device_id] if device_id else []
result = yield self.handle_request( result = yield self.handle_request(

View file

@ -40,7 +40,7 @@ class ReceiptRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id): def on_POST(self, request, room_id, receipt_type, event_id):
user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
if receipt_type != "m.read": if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'") raise SynapseError(400, "Receipt type must be 'm.read'")
@ -48,7 +48,7 @@ class ReceiptRestServlet(RestServlet):
yield self.receipts_handler.received_client_receipt( yield self.receipts_handler.received_client_receipt(
room_id, room_id,
receipt_type, receipt_type,
user_id=user.to_string(), user_id=requester.user.to_string(),
event_id=event_id event_id=event_id
) )

View file

@ -85,9 +85,10 @@ class SyncRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user, token_id, is_guest = yield self.auth.get_user_by_req( requester = yield self.auth.get_user_by_req(
request, allow_guest=True request, allow_guest=True
) )
user = requester.user
timeout = parse_integer(request, "timeout", default=0) timeout = parse_integer(request, "timeout", default=0)
since = parse_string(request, "since") since = parse_string(request, "since")
@ -123,7 +124,7 @@ class SyncRestServlet(RestServlet):
sync_config = SyncConfig( sync_config = SyncConfig(
user=user, user=user,
filter=filter, filter=filter,
is_guest=is_guest, is_guest=requester.is_guest,
) )
if since is not None: if since is not None:
@ -146,15 +147,15 @@ class SyncRestServlet(RestServlet):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
joined = self.encode_joined( joined = self.encode_joined(
sync_result.joined, filter, time_now, token_id sync_result.joined, filter, time_now, requester.access_token_id
) )
invited = self.encode_invited( invited = self.encode_invited(
sync_result.invited, filter, time_now, token_id sync_result.invited, filter, time_now, requester.access_token_id
) )
archived = self.encode_archived( archived = self.encode_archived(
sync_result.archived, filter, time_now, token_id sync_result.archived, filter, time_now, requester.access_token_id
) )
response_content = { response_content = {

View file

@ -42,8 +42,8 @@ class TagListServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, room_id): def on_GET(self, request, user_id, room_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get tags for other users.") raise AuthError(403, "Cannot get tags for other users.")
tags = yield self.store.get_tags_for_room(user_id, room_id) tags = yield self.store.get_tags_for_room(user_id, room_id)
@ -68,8 +68,8 @@ class TagServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, tag): def on_PUT(self, request, user_id, room_id, tag):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.") raise AuthError(403, "Cannot add tags for other users.")
try: try:
@ -88,8 +88,8 @@ class TagServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_DELETE(self, request, user_id, room_id, tag): def on_DELETE(self, request, user_id, room_id, tag):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.") raise AuthError(403, "Cannot add tags for other users.")
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)

View file

@ -66,11 +66,11 @@ class ContentRepoResource(resource.Resource):
@defer.inlineCallbacks @defer.inlineCallbacks
def map_request_to_name(self, request): def map_request_to_name(self, request):
# auth the user # auth the user
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
# namespace all file uploads on the user # namespace all file uploads on the user
prefix = base64.urlsafe_b64encode( prefix = base64.urlsafe_b64encode(
auth_user.to_string() requester.user.to_string()
).replace('=', '') ).replace('=', '')
# use a random string for the main portion # use a random string for the main portion
@ -94,7 +94,7 @@ class ContentRepoResource(resource.Resource):
file_name = prefix + main_part + suffix file_name = prefix + main_part + suffix
file_path = os.path.join(self.directory, file_name) file_path = os.path.join(self.directory, file_name)
logger.info("User %s is uploading a file to path %s", logger.info("User %s is uploading a file to path %s",
auth_user.to_string(), request.user.user_id.to_string(),
file_path) file_path)
# keep trying to make a non-clashing file, with a sensible max attempts # keep trying to make a non-clashing file, with a sensible max attempts

View file

@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource):
@request_handler @request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_POST(self, request): def _async_render_POST(self, request):
auth_user, _, _ = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have # TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point # already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length") content_length = request.getHeader("Content-Length")
@ -110,7 +110,7 @@ class UploadResource(BaseMediaResource):
content_uri = yield self.create_content( content_uri = yield self.create_content(
media_type, upload_name, request.content.read(), media_type, upload_name, request.content.read(),
content_length, auth_user content_length, requester.user
) )
respond_with_json( respond_with_json(

View file

@ -18,6 +18,9 @@ from synapse.api.errors import SynapseError
from collections import namedtuple from collections import namedtuple
Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
class DomainSpecificString( class DomainSpecificString(
namedtuple("DomainSpecificString", ("localpart", "domain")) namedtuple("DomainSpecificString", ("localpart", "domain"))
): ):

View file

@ -51,8 +51,8 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, _, _) = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self): def test_get_user_by_req_user_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
@ -86,8 +86,8 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, _, _) = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_bad_token(self): def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
@ -121,8 +121,8 @@ class AuthTestCase(unittest.TestCase):
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, _, _) = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), masquerading_user_id) self.assertEquals(requester.user.to_string(), masquerading_user_id)
def test_get_user_by_req_appservice_valid_token_bad_user_id(self): def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org" masquerading_user_id = "@doppelganger:matrix.org"

View file

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
"""Tests REST events for /presence paths.""" """Tests REST events for /presence paths."""
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
@ -26,7 +25,7 @@ from synapse.api.constants import PresenceState
from synapse.handlers.presence import PresenceHandler from synapse.handlers.presence import PresenceHandler
from synapse.rest.client.v1 import presence from synapse.rest.client.v1 import presence
from synapse.rest.client.v1 import events from synapse.rest.client.v1 import events
from synapse.types import UserID from synapse.types import Requester, UserID
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from collections import namedtuple from collections import namedtuple
@ -301,7 +300,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
hs.get_clock().time_msec.return_value = 1000000 hs.get_clock().time_msec.return_value = 1000000
def _get_user_by_req(req=None, allow_guest=False): def _get_user_by_req(req=None, allow_guest=False):
return (UserID.from_string(myid), "", False) return Requester(UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req hs.get_v1auth().get_user_by_req = _get_user_by_req

View file

@ -14,16 +14,15 @@
# limitations under the License. # limitations under the License.
"""Tests REST events for /profile paths.""" """Tests REST events for /profile paths."""
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from mock import Mock, NonCallableMock from mock import Mock
from ....utils import MockHttpResource, setup_test_homeserver from ....utils import MockHttpResource, setup_test_homeserver
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.types import UserID from synapse.types import Requester, UserID
from synapse.rest.client.v1 import profile from synapse.rest.client.v1 import profile
@ -53,7 +52,7 @@ class ProfileTestCase(unittest.TestCase):
) )
def _get_user_by_req(request=None, allow_guest=False): def _get_user_by_req(request=None, allow_guest=False):
return (UserID.from_string(myid), "", False) return Requester(UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req hs.get_v1auth().get_user_by_req = _get_user_by_req