forked from MirrorHub/synapse
Introduce a Requester object
This tracks data about the entity which made the request. This is instead of passing around a tuple, which requires call-site modifications every time a new piece of optional context is passed around. I tried to introduce a User object. I gave up.
This commit is contained in:
parent
fcbe63eaad
commit
2110e35fd6
24 changed files with 178 additions and 144 deletions
|
@ -22,7 +22,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||||
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
||||||
from synapse.types import RoomID, UserID, EventID
|
from synapse.types import Requester, RoomID, UserID, EventID
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
|
|
||||||
|
@ -534,7 +534,9 @@ class Auth(object):
|
||||||
|
|
||||||
request.authenticated_entity = user_id
|
request.authenticated_entity = user_id
|
||||||
|
|
||||||
defer.returnValue((UserID.from_string(user_id), "", False))
|
defer.returnValue(
|
||||||
|
Requester(UserID.from_string(user_id), "", False)
|
||||||
|
)
|
||||||
return
|
return
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass # normal users won't have the user_id query parameter set.
|
pass # normal users won't have the user_id query parameter set.
|
||||||
|
@ -564,7 +566,7 @@ class Auth(object):
|
||||||
|
|
||||||
request.authenticated_entity = user.to_string()
|
request.authenticated_entity = user.to_string()
|
||||||
|
|
||||||
defer.returnValue((user, token_id, is_guest,))
|
defer.returnValue(Requester(user, token_id, is_guest))
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
||||||
|
|
|
@ -31,8 +31,9 @@ class WhoisRestServlet(ClientV1RestServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
is_admin = yield self.auth.is_server_admin(auth_user)
|
auth_user = requester.user
|
||||||
|
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||||
|
|
||||||
if not is_admin and target_user != auth_user:
|
if not is_admin and target_user != auth_user:
|
||||||
raise AuthError(403, "You are not a server admin")
|
raise AuthError(403, "You are not a server admin")
|
||||||
|
|
|
@ -69,9 +69,9 @@ class ClientDirectoryServer(ClientV1RestServlet):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# try to auth as a user
|
# try to auth as a user
|
||||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
try:
|
try:
|
||||||
user_id = user.to_string()
|
user_id = requester.user.to_string()
|
||||||
yield dir_handler.create_association(
|
yield dir_handler.create_association(
|
||||||
user_id, room_alias, room_id, servers
|
user_id, room_alias, room_id, servers
|
||||||
)
|
)
|
||||||
|
@ -116,8 +116,8 @@ class ClientDirectoryServer(ClientV1RestServlet):
|
||||||
# fallback to default user behaviour if they aren't an AS
|
# fallback to default user behaviour if they aren't an AS
|
||||||
pass
|
pass
|
||||||
|
|
||||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user = requester.user
|
||||||
is_admin = yield self.auth.is_server_admin(user)
|
is_admin = yield self.auth.is_server_admin(user)
|
||||||
if not is_admin:
|
if not is_admin:
|
||||||
raise AuthError(403, "You need to be a server admin")
|
raise AuthError(403, "You need to be a server admin")
|
||||||
|
|
|
@ -34,10 +34,11 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
auth_user, _, is_guest = yield self.auth.get_user_by_req(
|
requester = yield self.auth.get_user_by_req(
|
||||||
request,
|
request,
|
||||||
allow_guest=True
|
allow_guest=True,
|
||||||
)
|
)
|
||||||
|
is_guest = requester.is_guest
|
||||||
room_id = None
|
room_id = None
|
||||||
if is_guest:
|
if is_guest:
|
||||||
if "room_id" not in request.args:
|
if "room_id" not in request.args:
|
||||||
|
@ -56,9 +57,13 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
||||||
as_client_event = "raw" not in request.args
|
as_client_event = "raw" not in request.args
|
||||||
|
|
||||||
chunk = yield handler.get_stream(
|
chunk = yield handler.get_stream(
|
||||||
auth_user.to_string(), pagin_config, timeout=timeout,
|
requester.user.to_string(),
|
||||||
as_client_event=as_client_event, affect_presence=(not is_guest),
|
pagin_config,
|
||||||
room_id=room_id, is_guest=is_guest
|
timeout=timeout,
|
||||||
|
as_client_event=as_client_event,
|
||||||
|
affect_presence=(not is_guest),
|
||||||
|
room_id=room_id,
|
||||||
|
is_guest=is_guest,
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
logger.exception("Event stream failed")
|
logger.exception("Event stream failed")
|
||||||
|
@ -80,9 +85,9 @@ class EventRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, event_id):
|
def on_GET(self, request, event_id):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
handler = self.handlers.event_handler
|
handler = self.handlers.event_handler
|
||||||
event = yield handler.get_event(auth_user, event_id)
|
event = yield handler.get_event(requester.user, event_id)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
if event:
|
if event:
|
||||||
|
|
|
@ -25,13 +25,13 @@ class InitialSyncRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
as_client_event = "raw" not in request.args
|
as_client_event = "raw" not in request.args
|
||||||
pagination_config = PaginationConfig.from_request(request)
|
pagination_config = PaginationConfig.from_request(request)
|
||||||
handler = self.handlers.message_handler
|
handler = self.handlers.message_handler
|
||||||
include_archived = request.args.get("archived", None) == ["true"]
|
include_archived = request.args.get("archived", None) == ["true"]
|
||||||
content = yield handler.snapshot_all_rooms(
|
content = yield handler.snapshot_all_rooms(
|
||||||
user_id=user.to_string(),
|
user_id=requester.user.to_string(),
|
||||||
pagin_config=pagination_config,
|
pagin_config=pagination_config,
|
||||||
as_client_event=as_client_event,
|
as_client_event=as_client_event,
|
||||||
include_archived=include_archived,
|
include_archived=include_archived,
|
||||||
|
|
|
@ -32,17 +32,17 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
state = yield self.handlers.presence_handler.get_state(
|
state = yield self.handlers.presence_handler.get_state(
|
||||||
target_user=user, auth_user=auth_user)
|
target_user=user, auth_user=requester.user)
|
||||||
|
|
||||||
defer.returnValue((200, state))
|
defer.returnValue((200, state))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, user_id):
|
def on_PUT(self, request, user_id):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
state = {}
|
state = {}
|
||||||
|
@ -64,7 +64,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
||||||
raise SynapseError(400, "Unable to parse state")
|
raise SynapseError(400, "Unable to parse state")
|
||||||
|
|
||||||
yield self.handlers.presence_handler.set_state(
|
yield self.handlers.presence_handler.set_state(
|
||||||
target_user=user, auth_user=auth_user, state=state)
|
target_user=user, auth_user=requester.user, state=state)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
@ -77,13 +77,13 @@ class PresenceListRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
if not self.hs.is_mine(user):
|
if not self.hs.is_mine(user):
|
||||||
raise SynapseError(400, "User not hosted on this Home Server")
|
raise SynapseError(400, "User not hosted on this Home Server")
|
||||||
|
|
||||||
if auth_user != user:
|
if requester.user != user:
|
||||||
raise SynapseError(400, "Cannot get another user's presence list")
|
raise SynapseError(400, "Cannot get another user's presence list")
|
||||||
|
|
||||||
presence = yield self.handlers.presence_handler.get_presence_list(
|
presence = yield self.handlers.presence_handler.get_presence_list(
|
||||||
|
@ -97,13 +97,13 @@ class PresenceListRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, user_id):
|
def on_POST(self, request, user_id):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
if not self.hs.is_mine(user):
|
if not self.hs.is_mine(user):
|
||||||
raise SynapseError(400, "User not hosted on this Home Server")
|
raise SynapseError(400, "User not hosted on this Home Server")
|
||||||
|
|
||||||
if auth_user != user:
|
if requester.user != user:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400, "Cannot modify another user's presence list")
|
400, "Cannot modify another user's presence list")
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, user_id):
|
def on_PUT(self, request, user_id):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -47,7 +47,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
||||||
defer.returnValue((400, "Unable to parse name"))
|
defer.returnValue((400, "Unable to parse name"))
|
||||||
|
|
||||||
yield self.handlers.profile_handler.set_displayname(
|
yield self.handlers.profile_handler.set_displayname(
|
||||||
user, auth_user, new_name)
|
user, requester.user, new_name)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, user_id):
|
def on_PUT(self, request, user_id):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -80,7 +80,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
||||||
defer.returnValue((400, "Unable to parse name"))
|
defer.returnValue((400, "Unable to parse name"))
|
||||||
|
|
||||||
yield self.handlers.profile_handler.set_avatar_url(
|
yield self.handlers.profile_handler.set_avatar_url(
|
||||||
user, auth_user, new_name)
|
user, requester.user, new_name)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
|
@ -43,7 +43,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
except InvalidRuleException as e:
|
except InvalidRuleException as e:
|
||||||
raise SynapseError(400, e.message)
|
raise SynapseError(400, e.message)
|
||||||
|
|
||||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
|
if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
|
||||||
raise SynapseError(400, "rule_id may not contain slashes")
|
raise SynapseError(400, "rule_id may not contain slashes")
|
||||||
|
@ -51,7 +51,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
content = _parse_json(request)
|
content = _parse_json(request)
|
||||||
|
|
||||||
if 'attr' in spec:
|
if 'attr' in spec:
|
||||||
self.set_rule_attr(user.to_string(), spec, content)
|
self.set_rule_attr(requester.user, spec, content)
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -73,7 +73,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.hs.get_datastore().add_push_rule(
|
yield self.hs.get_datastore().add_push_rule(
|
||||||
user_name=user.to_string(),
|
user_name=requester.user.to_string(),
|
||||||
rule_id=_namespaced_rule_id_from_spec(spec),
|
rule_id=_namespaced_rule_id_from_spec(spec),
|
||||||
priority_class=priority_class,
|
priority_class=priority_class,
|
||||||
conditions=conditions,
|
conditions=conditions,
|
||||||
|
@ -92,13 +92,13 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
def on_DELETE(self, request):
|
def on_DELETE(self, request):
|
||||||
spec = _rule_spec_from_path(request.postpath)
|
spec = _rule_spec_from_path(request.postpath)
|
||||||
|
|
||||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.hs.get_datastore().delete_push_rule(
|
yield self.hs.get_datastore().delete_push_rule(
|
||||||
user.to_string(), namespaced_rule_id
|
requester.user.to_string(), namespaced_rule_id
|
||||||
)
|
)
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
except StoreError as e:
|
except StoreError as e:
|
||||||
|
@ -109,7 +109,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user = requester.user
|
||||||
|
|
||||||
# we build up the full structure and then decide which bits of it
|
# we build up the full structure and then decide which bits of it
|
||||||
# to send which means doing unnecessary work sometimes but is
|
# to send which means doing unnecessary work sometimes but is
|
||||||
|
|
|
@ -30,7 +30,8 @@ class PusherRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user = requester.user
|
||||||
|
|
||||||
content = _parse_json(request)
|
content = _parse_json(request)
|
||||||
|
|
||||||
|
@ -71,7 +72,7 @@ class PusherRestServlet(ClientV1RestServlet):
|
||||||
try:
|
try:
|
||||||
yield pusher_pool.add_pusher(
|
yield pusher_pool.add_pusher(
|
||||||
user_name=user.to_string(),
|
user_name=user.to_string(),
|
||||||
access_token=token_id,
|
access_token=requester.access_token_id,
|
||||||
profile_tag=content['profile_tag'],
|
profile_tag=content['profile_tag'],
|
||||||
kind=content['kind'],
|
kind=content['kind'],
|
||||||
app_id=content['app_id'],
|
app_id=content['app_id'],
|
||||||
|
|
|
@ -61,10 +61,14 @@ class RoomCreateRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
room_config = self.get_room_config(request)
|
room_config = self.get_room_config(request)
|
||||||
info = yield self.make_room(room_config, auth_user, None)
|
info = yield self.make_room(
|
||||||
|
room_config,
|
||||||
|
requester.user,
|
||||||
|
None,
|
||||||
|
)
|
||||||
room_config.update(info)
|
room_config.update(info)
|
||||||
defer.returnValue((200, info))
|
defer.returnValue((200, info))
|
||||||
|
|
||||||
|
@ -124,15 +128,15 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id, event_type, state_key):
|
def on_GET(self, request, room_id, event_type, state_key):
|
||||||
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
msg_handler = self.handlers.message_handler
|
msg_handler = self.handlers.message_handler
|
||||||
data = yield msg_handler.get_room_data(
|
data = yield msg_handler.get_room_data(
|
||||||
user_id=user.to_string(),
|
user_id=requester.user.to_string(),
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
event_type=event_type,
|
event_type=event_type,
|
||||||
state_key=state_key,
|
state_key=state_key,
|
||||||
is_guest=is_guest,
|
is_guest=requester.is_guest,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not data:
|
if not data:
|
||||||
|
@ -143,7 +147,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
|
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
|
||||||
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
content = _parse_json(request)
|
content = _parse_json(request)
|
||||||
|
|
||||||
|
@ -151,7 +155,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
||||||
"type": event_type,
|
"type": event_type,
|
||||||
"content": content,
|
"content": content,
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"sender": user.to_string(),
|
"sender": requester.user.to_string(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if state_key is not None:
|
if state_key is not None:
|
||||||
|
@ -159,7 +163,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
msg_handler = self.handlers.message_handler
|
msg_handler = self.handlers.message_handler
|
||||||
yield msg_handler.create_and_send_event(
|
yield msg_handler.create_and_send_event(
|
||||||
event_dict, token_id=token_id, txn_id=txn_id,
|
event_dict, token_id=requester.access_token_id, txn_id=txn_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
@ -175,7 +179,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, room_id, event_type, txn_id=None):
|
def on_POST(self, request, room_id, event_type, txn_id=None):
|
||||||
user, token_id, _ = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
content = _parse_json(request)
|
content = _parse_json(request)
|
||||||
|
|
||||||
msg_handler = self.handlers.message_handler
|
msg_handler = self.handlers.message_handler
|
||||||
|
@ -184,9 +188,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
||||||
"type": event_type,
|
"type": event_type,
|
||||||
"content": content,
|
"content": content,
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"sender": user.to_string(),
|
"sender": requester.user.to_string(),
|
||||||
},
|
},
|
||||||
token_id=token_id,
|
token_id=requester.access_token_id,
|
||||||
txn_id=txn_id,
|
txn_id=txn_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -220,9 +224,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, room_identifier, txn_id=None):
|
def on_POST(self, request, room_identifier, txn_id=None):
|
||||||
user, token_id, is_guest = yield self.auth.get_user_by_req(
|
requester = yield self.auth.get_user_by_req(
|
||||||
request,
|
request,
|
||||||
allow_guest=True
|
allow_guest=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# the identifier could be a room alias or a room id. Try one then the
|
# the identifier could be a room alias or a room id. Try one then the
|
||||||
|
@ -241,24 +245,27 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
if is_room_alias:
|
if is_room_alias:
|
||||||
handler = self.handlers.room_member_handler
|
handler = self.handlers.room_member_handler
|
||||||
ret_dict = yield handler.join_room_alias(user, identifier)
|
ret_dict = yield handler.join_room_alias(
|
||||||
|
requester.user,
|
||||||
|
identifier,
|
||||||
|
)
|
||||||
defer.returnValue((200, ret_dict))
|
defer.returnValue((200, ret_dict))
|
||||||
else: # room id
|
else: # room id
|
||||||
msg_handler = self.handlers.message_handler
|
msg_handler = self.handlers.message_handler
|
||||||
content = {"membership": Membership.JOIN}
|
content = {"membership": Membership.JOIN}
|
||||||
if is_guest:
|
if requester.is_guest:
|
||||||
content["kind"] = "guest"
|
content["kind"] = "guest"
|
||||||
yield msg_handler.create_and_send_event(
|
yield msg_handler.create_and_send_event(
|
||||||
{
|
{
|
||||||
"type": EventTypes.Member,
|
"type": EventTypes.Member,
|
||||||
"content": content,
|
"content": content,
|
||||||
"room_id": identifier.to_string(),
|
"room_id": identifier.to_string(),
|
||||||
"sender": user.to_string(),
|
"sender": requester.user.to_string(),
|
||||||
"state_key": user.to_string(),
|
"state_key": requester.user.to_string(),
|
||||||
},
|
},
|
||||||
token_id=token_id,
|
token_id=requester.access_token_id,
|
||||||
txn_id=txn_id,
|
txn_id=txn_id,
|
||||||
is_guest=is_guest,
|
is_guest=requester.is_guest,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {"room_id": identifier.to_string()}))
|
defer.returnValue((200, {"room_id": identifier.to_string()}))
|
||||||
|
@ -296,11 +303,11 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
# TODO support Pagination stream API (limit/tokens)
|
# TODO support Pagination stream API (limit/tokens)
|
||||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
handler = self.handlers.message_handler
|
handler = self.handlers.message_handler
|
||||||
events = yield handler.get_state_events(
|
events = yield handler.get_state_events(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
user_id=user.to_string(),
|
user_id=requester.user.to_string(),
|
||||||
)
|
)
|
||||||
|
|
||||||
chunk = []
|
chunk = []
|
||||||
|
@ -315,7 +322,8 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
|
||||||
try:
|
try:
|
||||||
presence_handler = self.handlers.presence_handler
|
presence_handler = self.handlers.presence_handler
|
||||||
presence_state = yield presence_handler.get_state(
|
presence_state = yield presence_handler.get_state(
|
||||||
target_user=target_user, auth_user=user
|
target_user=target_user,
|
||||||
|
auth_user=requester.user,
|
||||||
)
|
)
|
||||||
event["content"].update(presence_state)
|
event["content"].update(presence_state)
|
||||||
except:
|
except:
|
||||||
|
@ -332,7 +340,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
pagination_config = PaginationConfig.from_request(
|
pagination_config = PaginationConfig.from_request(
|
||||||
request, default_limit=10,
|
request, default_limit=10,
|
||||||
)
|
)
|
||||||
|
@ -340,8 +348,8 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
||||||
handler = self.handlers.message_handler
|
handler = self.handlers.message_handler
|
||||||
msgs = yield handler.get_messages(
|
msgs = yield handler.get_messages(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
user_id=user.to_string(),
|
user_id=requester.user.to_string(),
|
||||||
is_guest=is_guest,
|
is_guest=requester.is_guest,
|
||||||
pagin_config=pagination_config,
|
pagin_config=pagination_config,
|
||||||
as_client_event=as_client_event
|
as_client_event=as_client_event
|
||||||
)
|
)
|
||||||
|
@ -355,13 +363,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
handler = self.handlers.message_handler
|
handler = self.handlers.message_handler
|
||||||
# Get all the current state for this room
|
# Get all the current state for this room
|
||||||
events = yield handler.get_state_events(
|
events = yield handler.get_state_events(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
user_id=user.to_string(),
|
user_id=requester.user.to_string(),
|
||||||
is_guest=is_guest,
|
is_guest=requester.is_guest,
|
||||||
)
|
)
|
||||||
defer.returnValue((200, events))
|
defer.returnValue((200, events))
|
||||||
|
|
||||||
|
@ -372,13 +380,13 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
pagination_config = PaginationConfig.from_request(request)
|
pagination_config = PaginationConfig.from_request(request)
|
||||||
content = yield self.handlers.message_handler.room_initial_sync(
|
content = yield self.handlers.message_handler.room_initial_sync(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
user_id=user.to_string(),
|
user_id=requester.user.to_string(),
|
||||||
pagin_config=pagination_config,
|
pagin_config=pagination_config,
|
||||||
is_guest=is_guest,
|
is_guest=requester.is_guest,
|
||||||
)
|
)
|
||||||
defer.returnValue((200, content))
|
defer.returnValue((200, content))
|
||||||
|
|
||||||
|
@ -394,12 +402,16 @@ class RoomEventContext(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id, event_id):
|
def on_GET(self, request, room_id, event_id):
|
||||||
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
limit = int(request.args.get("limit", [10])[0])
|
limit = int(request.args.get("limit", [10])[0])
|
||||||
|
|
||||||
results = yield self.handlers.room_context_handler.get_event_context(
|
results = yield self.handlers.room_context_handler.get_event_context(
|
||||||
user, room_id, event_id, limit, is_guest
|
requester.user,
|
||||||
|
room_id,
|
||||||
|
event_id,
|
||||||
|
limit,
|
||||||
|
requester.is_guest,
|
||||||
)
|
)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
@ -429,14 +441,18 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, room_id, membership_action, txn_id=None):
|
def on_POST(self, request, room_id, membership_action, txn_id=None):
|
||||||
user, token_id, is_guest = yield self.auth.get_user_by_req(
|
requester = yield self.auth.get_user_by_req(
|
||||||
request,
|
request,
|
||||||
allow_guest=True
|
allow_guest=True,
|
||||||
)
|
)
|
||||||
|
user = requester.user
|
||||||
|
|
||||||
effective_membership_action = membership_action
|
effective_membership_action = membership_action
|
||||||
|
|
||||||
if is_guest and membership_action not in {Membership.JOIN, Membership.LEAVE}:
|
if requester.is_guest and membership_action not in {
|
||||||
|
Membership.JOIN,
|
||||||
|
Membership.LEAVE
|
||||||
|
}:
|
||||||
raise AuthError(403, "Guest access not allowed")
|
raise AuthError(403, "Guest access not allowed")
|
||||||
|
|
||||||
content = _parse_json(request)
|
content = _parse_json(request)
|
||||||
|
@ -451,7 +467,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
||||||
content["medium"],
|
content["medium"],
|
||||||
content["address"],
|
content["address"],
|
||||||
content["id_server"],
|
content["id_server"],
|
||||||
token_id,
|
requester.access_token_id,
|
||||||
txn_id
|
txn_id
|
||||||
)
|
)
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
@ -473,7 +489,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
||||||
msg_handler = self.handlers.message_handler
|
msg_handler = self.handlers.message_handler
|
||||||
|
|
||||||
content = {"membership": unicode(effective_membership_action)}
|
content = {"membership": unicode(effective_membership_action)}
|
||||||
if is_guest:
|
if requester.is_guest:
|
||||||
content["kind"] = "guest"
|
content["kind"] = "guest"
|
||||||
|
|
||||||
yield msg_handler.create_and_send_event(
|
yield msg_handler.create_and_send_event(
|
||||||
|
@ -484,9 +500,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
||||||
"sender": user.to_string(),
|
"sender": user.to_string(),
|
||||||
"state_key": state_key,
|
"state_key": state_key,
|
||||||
},
|
},
|
||||||
token_id=token_id,
|
token_id=requester.access_token_id,
|
||||||
txn_id=txn_id,
|
txn_id=txn_id,
|
||||||
is_guest=is_guest,
|
is_guest=requester.is_guest,
|
||||||
)
|
)
|
||||||
|
|
||||||
if membership_action == "forget":
|
if membership_action == "forget":
|
||||||
|
@ -524,7 +540,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, room_id, event_id, txn_id=None):
|
def on_POST(self, request, room_id, event_id, txn_id=None):
|
||||||
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
content = _parse_json(request)
|
content = _parse_json(request)
|
||||||
|
|
||||||
msg_handler = self.handlers.message_handler
|
msg_handler = self.handlers.message_handler
|
||||||
|
@ -533,10 +549,10 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
|
||||||
"type": EventTypes.Redaction,
|
"type": EventTypes.Redaction,
|
||||||
"content": content,
|
"content": content,
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"sender": user.to_string(),
|
"sender": requester.user.to_string(),
|
||||||
"redacts": event_id,
|
"redacts": event_id,
|
||||||
},
|
},
|
||||||
token_id=token_id,
|
token_id=requester.access_token_id,
|
||||||
txn_id=txn_id,
|
txn_id=txn_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -564,7 +580,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, room_id, user_id):
|
def on_PUT(self, request, room_id, user_id):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
room_id = urllib.unquote(room_id)
|
room_id = urllib.unquote(room_id)
|
||||||
target_user = UserID.from_string(urllib.unquote(user_id))
|
target_user = UserID.from_string(urllib.unquote(user_id))
|
||||||
|
@ -576,14 +592,14 @@ class RoomTypingRestServlet(ClientV1RestServlet):
|
||||||
if content["typing"]:
|
if content["typing"]:
|
||||||
yield typing_handler.started_typing(
|
yield typing_handler.started_typing(
|
||||||
target_user=target_user,
|
target_user=target_user,
|
||||||
auth_user=auth_user,
|
auth_user=requester.user,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
timeout=content.get("timeout", 30000),
|
timeout=content.get("timeout", 30000),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield typing_handler.stopped_typing(
|
yield typing_handler.stopped_typing(
|
||||||
target_user=target_user,
|
target_user=target_user,
|
||||||
auth_user=auth_user,
|
auth_user=requester.user,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -597,12 +613,16 @@ class SearchRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
content = _parse_json(request)
|
content = _parse_json(request)
|
||||||
|
|
||||||
batch = request.args.get("next_batch", [None])[0]
|
batch = request.args.get("next_batch", [None])[0]
|
||||||
results = yield self.handlers.search_handler.search(auth_user, content, batch)
|
results = yield self.handlers.search_handler.search(
|
||||||
|
requester.user,
|
||||||
|
content,
|
||||||
|
batch,
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue((200, results))
|
defer.returnValue((200, results))
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
turnUris = self.hs.config.turn_uris
|
turnUris = self.hs.config.turn_uris
|
||||||
turnSecret = self.hs.config.turn_shared_secret
|
turnSecret = self.hs.config.turn_shared_secret
|
||||||
|
@ -37,7 +37,7 @@ class VoipRestServlet(ClientV1RestServlet):
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
|
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
|
||||||
username = "%d:%s" % (expiry, auth_user.to_string())
|
username = "%d:%s" % (expiry, requester.user.to_string())
|
||||||
|
|
||||||
mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
|
mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
|
||||||
# We need to use standard padded base64 encoding here
|
# We need to use standard padded base64 encoding here
|
||||||
|
|
|
@ -55,10 +55,11 @@ class PasswordRestServlet(RestServlet):
|
||||||
|
|
||||||
if LoginType.PASSWORD in result:
|
if LoginType.PASSWORD in result:
|
||||||
# if using password, they should also be logged in
|
# if using password, they should also be logged in
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
if auth_user.to_string() != result[LoginType.PASSWORD]:
|
requester_user_id = requester.user.to_string()
|
||||||
|
if requester_user_id.to_string() != result[LoginType.PASSWORD]:
|
||||||
raise LoginError(400, "", Codes.UNKNOWN)
|
raise LoginError(400, "", Codes.UNKNOWN)
|
||||||
user_id = auth_user.to_string()
|
user_id = requester_user_id
|
||||||
elif LoginType.EMAIL_IDENTITY in result:
|
elif LoginType.EMAIL_IDENTITY in result:
|
||||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||||
if 'medium' not in threepid or 'address' not in threepid:
|
if 'medium' not in threepid or 'address' not in threepid:
|
||||||
|
@ -102,10 +103,10 @@ class ThreepidRestServlet(RestServlet):
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
threepids = yield self.hs.get_datastore().user_get_threepids(
|
threepids = yield self.hs.get_datastore().user_get_threepids(
|
||||||
auth_user.to_string()
|
requester.user.to_string()
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {'threepids': threepids}))
|
defer.returnValue((200, {'threepids': threepids}))
|
||||||
|
@ -120,7 +121,8 @@ class ThreepidRestServlet(RestServlet):
|
||||||
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
|
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
|
||||||
threePidCreds = body['threePidCreds']
|
threePidCreds = body['threePidCreds']
|
||||||
|
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
|
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
|
||||||
|
|
||||||
|
@ -135,7 +137,7 @@ class ThreepidRestServlet(RestServlet):
|
||||||
raise SynapseError(500, "Invalid response from ID Server")
|
raise SynapseError(500, "Invalid response from ID Server")
|
||||||
|
|
||||||
yield self.auth_handler.add_threepid(
|
yield self.auth_handler.add_threepid(
|
||||||
auth_user.to_string(),
|
user_id,
|
||||||
threepid['medium'],
|
threepid['medium'],
|
||||||
threepid['address'],
|
threepid['address'],
|
||||||
threepid['validated_at'],
|
threepid['validated_at'],
|
||||||
|
@ -144,10 +146,10 @@ class ThreepidRestServlet(RestServlet):
|
||||||
if 'bind' in body and body['bind']:
|
if 'bind' in body and body['bind']:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Binding emails %s to %s",
|
"Binding emails %s to %s",
|
||||||
threepid, auth_user.to_string()
|
threepid, user_id
|
||||||
)
|
)
|
||||||
yield self.identity_handler.bind_threepid(
|
yield self.identity_handler.bind_threepid(
|
||||||
threePidCreds, auth_user.to_string()
|
threePidCreds, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
|
@ -43,8 +43,8 @@ class AccountDataServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, user_id, account_data_type):
|
def on_PUT(self, request, user_id, account_data_type):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
if user_id != auth_user.to_string():
|
if user_id != requester.user.to_string():
|
||||||
raise AuthError(403, "Cannot add account data for other users.")
|
raise AuthError(403, "Cannot add account data for other users.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -82,8 +82,8 @@ class RoomAccountDataServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, user_id, room_id, account_data_type):
|
def on_PUT(self, request, user_id, room_id, account_data_type):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
if user_id != auth_user.to_string():
|
if user_id != requester.user.to_string():
|
||||||
raise AuthError(403, "Cannot add account data for other users.")
|
raise AuthError(403, "Cannot add account data for other users.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -40,9 +40,9 @@ class GetFilterRestServlet(RestServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id, filter_id):
|
def on_GET(self, request, user_id, filter_id):
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
if target_user != auth_user:
|
if target_user != requester.user:
|
||||||
raise AuthError(403, "Cannot get filters for other users")
|
raise AuthError(403, "Cannot get filters for other users")
|
||||||
|
|
||||||
if not self.hs.is_mine(target_user):
|
if not self.hs.is_mine(target_user):
|
||||||
|
@ -76,9 +76,9 @@ class CreateFilterRestServlet(RestServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, user_id):
|
def on_POST(self, request, user_id):
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
if target_user != auth_user:
|
if target_user != requester.user:
|
||||||
raise AuthError(403, "Cannot create filters for other users")
|
raise AuthError(403, "Cannot create filters for other users")
|
||||||
|
|
||||||
if not self.hs.is_mine(target_user):
|
if not self.hs.is_mine(target_user):
|
||||||
|
|
|
@ -64,8 +64,8 @@ class KeyUploadServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, device_id):
|
def on_POST(self, request, device_id):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = auth_user.to_string()
|
user_id = requester.user.to_string()
|
||||||
# TODO: Check that the device_id matches that in the authentication
|
# TODO: Check that the device_id matches that in the authentication
|
||||||
# or derive the device_id from the authentication instead.
|
# or derive the device_id from the authentication instead.
|
||||||
try:
|
try:
|
||||||
|
@ -78,8 +78,8 @@ class KeyUploadServlet(RestServlet):
|
||||||
device_keys = body.get("device_keys", None)
|
device_keys = body.get("device_keys", None)
|
||||||
if device_keys:
|
if device_keys:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Updating device_keys for device %r for user %r at %d",
|
"Updating device_keys for device %r for user %s at %d",
|
||||||
device_id, auth_user, time_now
|
device_id, user_id, time_now
|
||||||
)
|
)
|
||||||
# TODO: Sign the JSON with the server key
|
# TODO: Sign the JSON with the server key
|
||||||
yield self.store.set_e2e_device_keys(
|
yield self.store.set_e2e_device_keys(
|
||||||
|
@ -109,8 +109,8 @@ class KeyUploadServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, device_id):
|
def on_GET(self, request, device_id):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = auth_user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||||
defer.returnValue((200, {"one_time_key_counts": result}))
|
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||||
|
@ -182,8 +182,8 @@ class KeyQueryServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id, device_id):
|
def on_GET(self, request, user_id, device_id):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
auth_user_id = auth_user.to_string()
|
auth_user_id = requester.user.to_string()
|
||||||
user_id = user_id if user_id else auth_user_id
|
user_id = user_id if user_id else auth_user_id
|
||||||
device_ids = [device_id] if device_id else []
|
device_ids = [device_id] if device_id else []
|
||||||
result = yield self.handle_request(
|
result = yield self.handle_request(
|
||||||
|
|
|
@ -40,7 +40,7 @@ class ReceiptRestServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, room_id, receipt_type, event_id):
|
def on_POST(self, request, room_id, receipt_type, event_id):
|
||||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
if receipt_type != "m.read":
|
if receipt_type != "m.read":
|
||||||
raise SynapseError(400, "Receipt type must be 'm.read'")
|
raise SynapseError(400, "Receipt type must be 'm.read'")
|
||||||
|
@ -48,7 +48,7 @@ class ReceiptRestServlet(RestServlet):
|
||||||
yield self.receipts_handler.received_client_receipt(
|
yield self.receipts_handler.received_client_receipt(
|
||||||
room_id,
|
room_id,
|
||||||
receipt_type,
|
receipt_type,
|
||||||
user_id=user.to_string(),
|
user_id=requester.user.to_string(),
|
||||||
event_id=event_id
|
event_id=event_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -85,9 +85,10 @@ class SyncRestServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
user, token_id, is_guest = yield self.auth.get_user_by_req(
|
requester = yield self.auth.get_user_by_req(
|
||||||
request, allow_guest=True
|
request, allow_guest=True
|
||||||
)
|
)
|
||||||
|
user = requester.user
|
||||||
|
|
||||||
timeout = parse_integer(request, "timeout", default=0)
|
timeout = parse_integer(request, "timeout", default=0)
|
||||||
since = parse_string(request, "since")
|
since = parse_string(request, "since")
|
||||||
|
@ -123,7 +124,7 @@ class SyncRestServlet(RestServlet):
|
||||||
sync_config = SyncConfig(
|
sync_config = SyncConfig(
|
||||||
user=user,
|
user=user,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
is_guest=is_guest,
|
is_guest=requester.is_guest,
|
||||||
)
|
)
|
||||||
|
|
||||||
if since is not None:
|
if since is not None:
|
||||||
|
@ -146,15 +147,15 @@ class SyncRestServlet(RestServlet):
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
joined = self.encode_joined(
|
joined = self.encode_joined(
|
||||||
sync_result.joined, filter, time_now, token_id
|
sync_result.joined, filter, time_now, requester.access_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
invited = self.encode_invited(
|
invited = self.encode_invited(
|
||||||
sync_result.invited, filter, time_now, token_id
|
sync_result.invited, filter, time_now, requester.access_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
archived = self.encode_archived(
|
archived = self.encode_archived(
|
||||||
sync_result.archived, filter, time_now, token_id
|
sync_result.archived, filter, time_now, requester.access_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
response_content = {
|
response_content = {
|
||||||
|
|
|
@ -42,8 +42,8 @@ class TagListServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id, room_id):
|
def on_GET(self, request, user_id, room_id):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
if user_id != auth_user.to_string():
|
if user_id != requester.user.to_string():
|
||||||
raise AuthError(403, "Cannot get tags for other users.")
|
raise AuthError(403, "Cannot get tags for other users.")
|
||||||
|
|
||||||
tags = yield self.store.get_tags_for_room(user_id, room_id)
|
tags = yield self.store.get_tags_for_room(user_id, room_id)
|
||||||
|
@ -68,8 +68,8 @@ class TagServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, user_id, room_id, tag):
|
def on_PUT(self, request, user_id, room_id, tag):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
if user_id != auth_user.to_string():
|
if user_id != requester.user.to_string():
|
||||||
raise AuthError(403, "Cannot add tags for other users.")
|
raise AuthError(403, "Cannot add tags for other users.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -88,8 +88,8 @@ class TagServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_DELETE(self, request, user_id, room_id, tag):
|
def on_DELETE(self, request, user_id, room_id, tag):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
if user_id != auth_user.to_string():
|
if user_id != requester.user.to_string():
|
||||||
raise AuthError(403, "Cannot add tags for other users.")
|
raise AuthError(403, "Cannot add tags for other users.")
|
||||||
|
|
||||||
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
|
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
|
||||||
|
|
|
@ -66,11 +66,11 @@ class ContentRepoResource(resource.Resource):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def map_request_to_name(self, request):
|
def map_request_to_name(self, request):
|
||||||
# auth the user
|
# auth the user
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
# namespace all file uploads on the user
|
# namespace all file uploads on the user
|
||||||
prefix = base64.urlsafe_b64encode(
|
prefix = base64.urlsafe_b64encode(
|
||||||
auth_user.to_string()
|
requester.user.to_string()
|
||||||
).replace('=', '')
|
).replace('=', '')
|
||||||
|
|
||||||
# use a random string for the main portion
|
# use a random string for the main portion
|
||||||
|
@ -94,7 +94,7 @@ class ContentRepoResource(resource.Resource):
|
||||||
file_name = prefix + main_part + suffix
|
file_name = prefix + main_part + suffix
|
||||||
file_path = os.path.join(self.directory, file_name)
|
file_path = os.path.join(self.directory, file_name)
|
||||||
logger.info("User %s is uploading a file to path %s",
|
logger.info("User %s is uploading a file to path %s",
|
||||||
auth_user.to_string(),
|
request.user.user_id.to_string(),
|
||||||
file_path)
|
file_path)
|
||||||
|
|
||||||
# keep trying to make a non-clashing file, with a sensible max attempts
|
# keep trying to make a non-clashing file, with a sensible max attempts
|
||||||
|
|
|
@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource):
|
||||||
@request_handler
|
@request_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _async_render_POST(self, request):
|
def _async_render_POST(self, request):
|
||||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
# TODO: The checks here are a bit late. The content will have
|
# TODO: The checks here are a bit late. The content will have
|
||||||
# already been uploaded to a tmp file at this point
|
# already been uploaded to a tmp file at this point
|
||||||
content_length = request.getHeader("Content-Length")
|
content_length = request.getHeader("Content-Length")
|
||||||
|
@ -110,7 +110,7 @@ class UploadResource(BaseMediaResource):
|
||||||
|
|
||||||
content_uri = yield self.create_content(
|
content_uri = yield self.create_content(
|
||||||
media_type, upload_name, request.content.read(),
|
media_type, upload_name, request.content.read(),
|
||||||
content_length, auth_user
|
content_length, requester.user
|
||||||
)
|
)
|
||||||
|
|
||||||
respond_with_json(
|
respond_with_json(
|
||||||
|
|
|
@ -18,6 +18,9 @@ from synapse.api.errors import SynapseError
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
|
|
||||||
|
Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
|
||||||
|
|
||||||
|
|
||||||
class DomainSpecificString(
|
class DomainSpecificString(
|
||||||
namedtuple("DomainSpecificString", ("localpart", "domain"))
|
namedtuple("DomainSpecificString", ("localpart", "domain"))
|
||||||
):
|
):
|
||||||
|
|
|
@ -51,8 +51,8 @@ class AuthTestCase(unittest.TestCase):
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args["access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
||||||
(user, _, _) = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(user.to_string(), self.test_user)
|
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||||
|
|
||||||
def test_get_user_by_req_user_bad_token(self):
|
def test_get_user_by_req_user_bad_token(self):
|
||||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||||
|
@ -86,8 +86,8 @@ class AuthTestCase(unittest.TestCase):
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args["access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
||||||
(user, _, _) = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(user.to_string(), self.test_user)
|
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_bad_token(self):
|
def test_get_user_by_req_appservice_bad_token(self):
|
||||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||||
|
@ -121,8 +121,8 @@ class AuthTestCase(unittest.TestCase):
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args["access_token"] = [self.test_token]
|
||||||
request.args["user_id"] = [masquerading_user_id]
|
request.args["user_id"] = [masquerading_user_id]
|
||||||
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
||||||
(user, _, _) = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(user.to_string(), masquerading_user_id)
|
self.assertEquals(requester.user.to_string(), masquerading_user_id)
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
||||||
masquerading_user_id = "@doppelganger:matrix.org"
|
masquerading_user_id = "@doppelganger:matrix.org"
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Tests REST events for /presence paths."""
|
"""Tests REST events for /presence paths."""
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -26,7 +25,7 @@ from synapse.api.constants import PresenceState
|
||||||
from synapse.handlers.presence import PresenceHandler
|
from synapse.handlers.presence import PresenceHandler
|
||||||
from synapse.rest.client.v1 import presence
|
from synapse.rest.client.v1 import presence
|
||||||
from synapse.rest.client.v1 import events
|
from synapse.rest.client.v1 import events
|
||||||
from synapse.types import UserID
|
from synapse.types import Requester, UserID
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
@ -301,7 +300,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
|
||||||
hs.get_clock().time_msec.return_value = 1000000
|
hs.get_clock().time_msec.return_value = 1000000
|
||||||
|
|
||||||
def _get_user_by_req(req=None, allow_guest=False):
|
def _get_user_by_req(req=None, allow_guest=False):
|
||||||
return (UserID.from_string(myid), "", False)
|
return Requester(UserID.from_string(myid), "", False)
|
||||||
|
|
||||||
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
||||||
|
|
||||||
|
|
|
@ -14,16 +14,15 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Tests REST events for /profile paths."""
|
"""Tests REST events for /profile paths."""
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from mock import Mock, NonCallableMock
|
from mock import Mock
|
||||||
|
|
||||||
from ....utils import MockHttpResource, setup_test_homeserver
|
from ....utils import MockHttpResource, setup_test_homeserver
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError, AuthError
|
from synapse.api.errors import SynapseError, AuthError
|
||||||
from synapse.types import UserID
|
from synapse.types import Requester, UserID
|
||||||
|
|
||||||
from synapse.rest.client.v1 import profile
|
from synapse.rest.client.v1 import profile
|
||||||
|
|
||||||
|
@ -53,7 +52,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_user_by_req(request=None, allow_guest=False):
|
def _get_user_by_req(request=None, allow_guest=False):
|
||||||
return (UserID.from_string(myid), "", False)
|
return Requester(UserID.from_string(myid), "", False)
|
||||||
|
|
||||||
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue