0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-02 18:59:04 +02:00

Introduce a Requester object

This tracks data about the entity which made the request. This is
instead of passing around a tuple, which requires call-site
modifications every time a new piece of optional context is passed
around.

I tried to introduce a User object. I gave up.
This commit is contained in:
Daniel Wagner-Hall 2016-01-11 15:29:57 +00:00
parent fcbe63eaad
commit 2110e35fd6
24 changed files with 178 additions and 144 deletions

View file

@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import RoomID, UserID, EventID
from synapse.types import Requester, RoomID, UserID, EventID
from synapse.util.logutils import log_function
from unpaddedbase64 import decode_base64
@ -534,7 +534,9 @@ class Auth(object):
request.authenticated_entity = user_id
defer.returnValue((UserID.from_string(user_id), "", False))
defer.returnValue(
Requester(UserID.from_string(user_id), "", False)
)
return
except KeyError:
pass # normal users won't have the user_id query parameter set.
@ -564,7 +566,7 @@ class Auth(object):
request.authenticated_entity = user.to_string()
defer.returnValue((user, token_id, is_guest,))
defer.returnValue(Requester(user, token_id, is_guest))
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",

View file

@ -31,8 +31,9 @@ class WhoisRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
auth_user, _, _ = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(auth_user)
requester = yield self.auth.get_user_by_req(request)
auth_user = requester.user
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin and target_user != auth_user:
raise AuthError(403, "You are not a server admin")

View file

@ -69,9 +69,9 @@ class ClientDirectoryServer(ClientV1RestServlet):
try:
# try to auth as a user
user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
try:
user_id = user.to_string()
user_id = requester.user.to_string()
yield dir_handler.create_association(
user_id, room_alias, room_id, servers
)
@ -116,8 +116,8 @@ class ClientDirectoryServer(ClientV1RestServlet):
# fallback to default user behaviour if they aren't an AS
pass
user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = requester.user
is_admin = yield self.auth.is_server_admin(user)
if not is_admin:
raise AuthError(403, "You need to be a server admin")

View file

@ -34,10 +34,11 @@ class EventStreamRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
auth_user, _, is_guest = yield self.auth.get_user_by_req(
requester = yield self.auth.get_user_by_req(
request,
allow_guest=True
allow_guest=True,
)
is_guest = requester.is_guest
room_id = None
if is_guest:
if "room_id" not in request.args:
@ -56,9 +57,13 @@ class EventStreamRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args
chunk = yield handler.get_stream(
auth_user.to_string(), pagin_config, timeout=timeout,
as_client_event=as_client_event, affect_presence=(not is_guest),
room_id=room_id, is_guest=is_guest
requester.user.to_string(),
pagin_config,
timeout=timeout,
as_client_event=as_client_event,
affect_presence=(not is_guest),
room_id=room_id,
is_guest=is_guest,
)
except:
logger.exception("Event stream failed")
@ -80,9 +85,9 @@ class EventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, event_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler
event = yield handler.get_event(auth_user, event_id)
event = yield handler.get_event(requester.user, event_id)
time_now = self.clock.time_msec()
if event:

View file

@ -25,13 +25,13 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler
include_archived = request.args.get("archived", None) == ["true"]
content = yield handler.snapshot_all_rooms(
user_id=user.to_string(),
user_id=requester.user.to_string(),
pagin_config=pagination_config,
as_client_event=as_client_event,
include_archived=include_archived,

View file

@ -32,17 +32,17 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state(
target_user=user, auth_user=auth_user)
target_user=user, auth_user=requester.user)
defer.returnValue((200, state))
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
state = {}
@ -64,7 +64,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Unable to parse state")
yield self.handlers.presence_handler.set_state(
target_user=user, auth_user=auth_user, state=state)
target_user=user, auth_user=requester.user, state=state)
defer.returnValue((200, {}))
@ -77,13 +77,13 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server")
if auth_user != user:
if requester.user != user:
raise SynapseError(400, "Cannot get another user's presence list")
presence = yield self.handlers.presence_handler.get_presence_list(
@ -97,13 +97,13 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server")
if auth_user != user:
if requester.user != user:
raise SynapseError(
400, "Cannot modify another user's presence list")

View file

@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id)
try:
@ -47,7 +47,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_displayname(
user, auth_user, new_name)
user, requester.user, new_name)
defer.returnValue((200, {}))
@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
try:
@ -80,7 +80,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_avatar_url(
user, auth_user, new_name)
user, requester.user, new_name)
defer.returnValue((200, {}))

View file

@ -43,7 +43,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e:
raise SynapseError(400, e.message)
user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
raise SynapseError(400, "rule_id may not contain slashes")
@ -51,7 +51,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
content = _parse_json(request)
if 'attr' in spec:
self.set_rule_attr(user.to_string(), spec, content)
self.set_rule_attr(requester.user, spec, content)
defer.returnValue((200, {}))
try:
@ -73,7 +73,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
try:
yield self.hs.get_datastore().add_push_rule(
user_name=user.to_string(),
user_name=requester.user.to_string(),
rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class,
conditions=conditions,
@ -92,13 +92,13 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_DELETE(self, request):
spec = _rule_spec_from_path(request.postpath)
user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try:
yield self.hs.get_datastore().delete_push_rule(
user.to_string(), namespaced_rule_id
requester.user.to_string(), namespaced_rule_id
)
defer.returnValue((200, {}))
except StoreError as e:
@ -109,7 +109,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = requester.user
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is

View file

@ -30,7 +30,8 @@ class PusherRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
user, token_id, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = requester.user
content = _parse_json(request)
@ -71,7 +72,7 @@ class PusherRestServlet(ClientV1RestServlet):
try:
yield pusher_pool.add_pusher(
user_name=user.to_string(),
access_token=token_id,
access_token=requester.access_token_id,
profile_tag=content['profile_tag'],
kind=content['kind'],
app_id=content['app_id'],

View file

@ -61,10 +61,14 @@ class RoomCreateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request)
info = yield self.make_room(room_config, auth_user, None)
info = yield self.make_room(
room_config,
requester.user,
None,
)
room_config.update(info)
defer.returnValue((200, info))
@ -124,15 +128,15 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key):
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data(
user_id=user.to_string(),
user_id=requester.user.to_string(),
room_id=room_id,
event_type=event_type,
state_key=state_key,
is_guest=is_guest,
is_guest=requester.is_guest,
)
if not data:
@ -143,7 +147,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
user, token_id, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
@ -151,7 +155,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
"type": event_type,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
"sender": requester.user.to_string(),
}
if state_key is not None:
@ -159,7 +163,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event(
event_dict, token_id=token_id, txn_id=txn_id,
event_dict, token_id=requester.access_token_id, txn_id=txn_id,
)
defer.returnValue((200, {}))
@ -175,7 +179,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_type, txn_id=None):
user, token_id, _ = yield self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
@ -184,9 +188,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
"type": event_type,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
"sender": requester.user.to_string(),
},
token_id=token_id,
token_id=requester.access_token_id,
txn_id=txn_id,
)
@ -220,9 +224,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, room_identifier, txn_id=None):
user, token_id, is_guest = yield self.auth.get_user_by_req(
requester = yield self.auth.get_user_by_req(
request,
allow_guest=True
allow_guest=True,
)
# the identifier could be a room alias or a room id. Try one then the
@ -241,24 +245,27 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
if is_room_alias:
handler = self.handlers.room_member_handler
ret_dict = yield handler.join_room_alias(user, identifier)
ret_dict = yield handler.join_room_alias(
requester.user,
identifier,
)
defer.returnValue((200, ret_dict))
else: # room id
msg_handler = self.handlers.message_handler
content = {"membership": Membership.JOIN}
if is_guest:
if requester.is_guest:
content["kind"] = "guest"
yield msg_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": identifier.to_string(),
"sender": user.to_string(),
"state_key": user.to_string(),
"sender": requester.user.to_string(),
"state_key": requester.user.to_string(),
},
token_id=token_id,
token_id=requester.access_token_id,
txn_id=txn_id,
is_guest=is_guest,
is_guest=requester.is_guest,
)
defer.returnValue((200, {"room_id": identifier.to_string()}))
@ -296,11 +303,11 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler
events = yield handler.get_state_events(
room_id=room_id,
user_id=user.to_string(),
user_id=requester.user.to_string(),
)
chunk = []
@ -315,7 +322,8 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
try:
presence_handler = self.handlers.presence_handler
presence_state = yield presence_handler.get_state(
target_user=target_user, auth_user=user
target_user=target_user,
auth_user=requester.user,
)
event["content"].update(presence_state)
except:
@ -332,7 +340,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(
request, default_limit=10,
)
@ -340,8 +348,8 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
handler = self.handlers.message_handler
msgs = yield handler.get_messages(
room_id=room_id,
user_id=user.to_string(),
is_guest=is_guest,
user_id=requester.user.to_string(),
is_guest=requester.is_guest,
pagin_config=pagination_config,
as_client_event=as_client_event
)
@ -355,13 +363,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
handler = self.handlers.message_handler
# Get all the current state for this room
events = yield handler.get_state_events(
room_id=room_id,
user_id=user.to_string(),
is_guest=is_guest,
user_id=requester.user.to_string(),
is_guest=requester.is_guest,
)
defer.returnValue((200, events))
@ -372,13 +380,13 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id,
user_id=user.to_string(),
user_id=requester.user.to_string(),
pagin_config=pagination_config,
is_guest=is_guest,
is_guest=requester.is_guest,
)
defer.returnValue((200, content))
@ -394,12 +402,16 @@ class RoomEventContext(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
limit = int(request.args.get("limit", [10])[0])
results = yield self.handlers.room_context_handler.get_event_context(
user, room_id, event_id, limit, is_guest
requester.user,
room_id,
event_id,
limit,
requester.is_guest,
)
time_now = self.clock.time_msec()
@ -429,14 +441,18 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action, txn_id=None):
user, token_id, is_guest = yield self.auth.get_user_by_req(
requester = yield self.auth.get_user_by_req(
request,
allow_guest=True
allow_guest=True,
)
user = requester.user
effective_membership_action = membership_action
if is_guest and membership_action not in {Membership.JOIN, Membership.LEAVE}:
if requester.is_guest and membership_action not in {
Membership.JOIN,
Membership.LEAVE
}:
raise AuthError(403, "Guest access not allowed")
content = _parse_json(request)
@ -451,7 +467,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content["medium"],
content["address"],
content["id_server"],
token_id,
requester.access_token_id,
txn_id
)
defer.returnValue((200, {}))
@ -473,7 +489,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
msg_handler = self.handlers.message_handler
content = {"membership": unicode(effective_membership_action)}
if is_guest:
if requester.is_guest:
content["kind"] = "guest"
yield msg_handler.create_and_send_event(
@ -484,9 +500,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
"sender": user.to_string(),
"state_key": state_key,
},
token_id=token_id,
token_id=requester.access_token_id,
txn_id=txn_id,
is_guest=is_guest,
is_guest=requester.is_guest,
)
if membership_action == "forget":
@ -524,7 +540,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id, txn_id=None):
user, token_id, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
@ -533,10 +549,10 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
"type": EventTypes.Redaction,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
"sender": requester.user.to_string(),
"redacts": event_id,
},
token_id=token_id,
token_id=requester.access_token_id,
txn_id=txn_id,
)
@ -564,7 +580,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
room_id = urllib.unquote(room_id)
target_user = UserID.from_string(urllib.unquote(user_id))
@ -576,14 +592,14 @@ class RoomTypingRestServlet(ClientV1RestServlet):
if content["typing"]:
yield typing_handler.started_typing(
target_user=target_user,
auth_user=auth_user,
auth_user=requester.user,
room_id=room_id,
timeout=content.get("timeout", 30000),
)
else:
yield typing_handler.stopped_typing(
target_user=target_user,
auth_user=auth_user,
auth_user=requester.user,
room_id=room_id,
)
@ -597,12 +613,16 @@ class SearchRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
batch = request.args.get("next_batch", [None])[0]
results = yield self.handlers.search_handler.search(auth_user, content, batch)
results = yield self.handlers.search_handler.search(
requester.user,
content,
batch,
)
defer.returnValue((200, results))

View file

@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret
@ -37,7 +37,7 @@ class VoipRestServlet(ClientV1RestServlet):
defer.returnValue((200, {}))
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
username = "%d:%s" % (expiry, auth_user.to_string())
username = "%d:%s" % (expiry, requester.user.to_string())
mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
# We need to use standard padded base64 encoding here

View file

@ -55,10 +55,11 @@ class PasswordRestServlet(RestServlet):
if LoginType.PASSWORD in result:
# if using password, they should also be logged in
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if auth_user.to_string() != result[LoginType.PASSWORD]:
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
if requester_user_id.to_string() != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN)
user_id = auth_user.to_string()
user_id = requester_user_id
elif LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
if 'medium' not in threepid or 'address' not in threepid:
@ -102,10 +103,10 @@ class ThreepidRestServlet(RestServlet):
def on_GET(self, request):
yield run_on_reactor()
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
threepids = yield self.hs.get_datastore().user_get_threepids(
auth_user.to_string()
requester.user.to_string()
)
defer.returnValue((200, {'threepids': threepids}))
@ -120,7 +121,8 @@ class ThreepidRestServlet(RestServlet):
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
threePidCreds = body['threePidCreds']
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
@ -135,7 +137,7 @@ class ThreepidRestServlet(RestServlet):
raise SynapseError(500, "Invalid response from ID Server")
yield self.auth_handler.add_threepid(
auth_user.to_string(),
user_id,
threepid['medium'],
threepid['address'],
threepid['validated_at'],
@ -144,10 +146,10 @@ class ThreepidRestServlet(RestServlet):
if 'bind' in body and body['bind']:
logger.debug(
"Binding emails %s to %s",
threepid, auth_user.to_string()
threepid, user_id
)
yield self.identity_handler.bind_threepid(
threePidCreds, auth_user.to_string()
threePidCreds, user_id
)
defer.returnValue((200, {}))

View file

@ -43,8 +43,8 @@ class AccountDataServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id, account_data_type):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
try:
@ -82,8 +82,8 @@ class RoomAccountDataServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, account_data_type):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
try:

View file

@ -40,9 +40,9 @@ class GetFilterRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id)
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if target_user != auth_user:
if target_user != requester.user:
raise AuthError(403, "Cannot get filters for other users")
if not self.hs.is_mine(target_user):
@ -76,9 +76,9 @@ class CreateFilterRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id)
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if target_user != auth_user:
if target_user != requester.user:
raise AuthError(403, "Cannot create filters for other users")
if not self.hs.is_mine(target_user):

View file

@ -64,8 +64,8 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, device_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string()
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
# TODO: Check that the device_id matches that in the authentication
# or derive the device_id from the authentication instead.
try:
@ -78,8 +78,8 @@ class KeyUploadServlet(RestServlet):
device_keys = body.get("device_keys", None)
if device_keys:
logger.info(
"Updating device_keys for device %r for user %r at %d",
device_id, auth_user, time_now
"Updating device_keys for device %r for user %s at %d",
device_id, user_id, time_now
)
# TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys(
@ -109,8 +109,8 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, device_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string()
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
@ -182,8 +182,8 @@ class KeyQueryServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user_id = auth_user.to_string()
requester = yield self.auth.get_user_by_req(request)
auth_user_id = requester.user.to_string()
user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else []
result = yield self.handle_request(

View file

@ -40,7 +40,7 @@ class ReceiptRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id):
user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'")
@ -48,7 +48,7 @@ class ReceiptRestServlet(RestServlet):
yield self.receipts_handler.received_client_receipt(
room_id,
receipt_type,
user_id=user.to_string(),
user_id=requester.user.to_string(),
event_id=event_id
)

View file

@ -85,9 +85,10 @@ class SyncRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
user, token_id, is_guest = yield self.auth.get_user_by_req(
requester = yield self.auth.get_user_by_req(
request, allow_guest=True
)
user = requester.user
timeout = parse_integer(request, "timeout", default=0)
since = parse_string(request, "since")
@ -123,7 +124,7 @@ class SyncRestServlet(RestServlet):
sync_config = SyncConfig(
user=user,
filter=filter,
is_guest=is_guest,
is_guest=requester.is_guest,
)
if since is not None:
@ -146,15 +147,15 @@ class SyncRestServlet(RestServlet):
time_now = self.clock.time_msec()
joined = self.encode_joined(
sync_result.joined, filter, time_now, token_id
sync_result.joined, filter, time_now, requester.access_token_id
)
invited = self.encode_invited(
sync_result.invited, filter, time_now, token_id
sync_result.invited, filter, time_now, requester.access_token_id
)
archived = self.encode_archived(
sync_result.archived, filter, time_now, token_id
sync_result.archived, filter, time_now, requester.access_token_id
)
response_content = {

View file

@ -42,8 +42,8 @@ class TagListServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id, room_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get tags for other users.")
tags = yield self.store.get_tags_for_room(user_id, room_id)
@ -68,8 +68,8 @@ class TagServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, tag):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
try:
@ -88,8 +88,8 @@ class TagServlet(RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request, user_id, room_id, tag):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)

View file

@ -66,11 +66,11 @@ class ContentRepoResource(resource.Resource):
@defer.inlineCallbacks
def map_request_to_name(self, request):
# auth the user
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
# namespace all file uploads on the user
prefix = base64.urlsafe_b64encode(
auth_user.to_string()
requester.user.to_string()
).replace('=', '')
# use a random string for the main portion
@ -94,7 +94,7 @@ class ContentRepoResource(resource.Resource):
file_name = prefix + main_part + suffix
file_path = os.path.join(self.directory, file_name)
logger.info("User %s is uploading a file to path %s",
auth_user.to_string(),
request.user.user_id.to_string(),
file_path)
# keep trying to make a non-clashing file, with a sensible max attempts

View file

@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource):
@request_handler
@defer.inlineCallbacks
def _async_render_POST(self, request):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length")
@ -110,7 +110,7 @@ class UploadResource(BaseMediaResource):
content_uri = yield self.create_content(
media_type, upload_name, request.content.read(),
content_length, auth_user
content_length, requester.user
)
respond_with_json(

View file

@ -18,6 +18,9 @@ from synapse.api.errors import SynapseError
from collections import namedtuple
Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
class DomainSpecificString(
namedtuple("DomainSpecificString", ("localpart", "domain"))
):

View file

@ -51,8 +51,8 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user)
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
@ -86,8 +86,8 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user)
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
@ -121,8 +121,8 @@ class AuthTestCase(unittest.TestCase):
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), masquerading_user_id)
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), masquerading_user_id)
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org"

View file

@ -14,7 +14,6 @@
# limitations under the License.
"""Tests REST events for /presence paths."""
from tests import unittest
from twisted.internet import defer
@ -26,7 +25,7 @@ from synapse.api.constants import PresenceState
from synapse.handlers.presence import PresenceHandler
from synapse.rest.client.v1 import presence
from synapse.rest.client.v1 import events
from synapse.types import UserID
from synapse.types import Requester, UserID
from synapse.util.async import run_on_reactor
from collections import namedtuple
@ -301,7 +300,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
hs.get_clock().time_msec.return_value = 1000000
def _get_user_by_req(req=None, allow_guest=False):
return (UserID.from_string(myid), "", False)
return Requester(UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req

View file

@ -14,16 +14,15 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
from tests import unittest
from twisted.internet import defer
from mock import Mock, NonCallableMock
from mock import Mock
from ....utils import MockHttpResource, setup_test_homeserver
from synapse.api.errors import SynapseError, AuthError
from synapse.types import UserID
from synapse.types import Requester, UserID
from synapse.rest.client.v1 import profile
@ -53,7 +52,7 @@ class ProfileTestCase(unittest.TestCase):
)
def _get_user_by_req(request=None, allow_guest=False):
return (UserID.from_string(myid), "", False)
return Requester(UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req