Port room rest handlers to async/await

This commit is contained in:
Erik Johnston 2019-10-29 13:09:24 +00:00
parent 561133c3c5
commit 9be41bc121

View file

@ -21,8 +21,6 @@ from six.moves.urllib import parse as urlparse
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
@ -85,11 +83,10 @@ class RoomCreateRestServlet(TransactionRestServlet):
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request) return self.txns.fetch_or_execute_request(request, self.on_POST, request)
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
info = yield self._room_creation_handler.create_room( info = await self._room_creation_handler.create_room(
requester, self.get_room_config(request) requester, self.get_room_config(request)
) )
@ -154,15 +151,14 @@ class RoomStateEventRestServlet(TransactionRestServlet):
def on_PUT_no_state_key(self, request, room_id, event_type): def on_PUT_no_state_key(self, request, room_id, event_type):
return self.on_PUT(request, room_id, event_type, "") return self.on_PUT(request, room_id, event_type, "")
@defer.inlineCallbacks async def on_GET(self, request, room_id, event_type, state_key):
def on_GET(self, request, room_id, event_type, state_key): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
format = parse_string( format = parse_string(
request, "format", default="content", allowed_values=["content", "event"] request, "format", default="content", allowed_values=["content", "event"]
) )
msg_handler = self.message_handler msg_handler = self.message_handler
data = yield msg_handler.get_room_data( data = await msg_handler.get_room_data(
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
room_id=room_id, room_id=room_id,
event_type=event_type, event_type=event_type,
@ -179,9 +175,8 @@ class RoomStateEventRestServlet(TransactionRestServlet):
elif format == "content": elif format == "content":
return 200, data.get_dict()["content"] return 200, data.get_dict()["content"]
@defer.inlineCallbacks async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if txn_id: if txn_id:
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
@ -200,7 +195,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if event_type == EventTypes.Member: if event_type == EventTypes.Member:
membership = content.get("membership", None) membership = content.get("membership", None)
event = yield self.room_member_handler.update_membership( event = await self.room_member_handler.update_membership(
requester, requester,
target=UserID.from_string(state_key), target=UserID.from_string(state_key),
room_id=room_id, room_id=room_id,
@ -208,7 +203,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
content=content, content=content,
) )
else: else:
event = yield self.event_creation_handler.create_and_send_nonmember_event( event = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id requester, event_dict, txn_id=txn_id
) )
@ -231,9 +226,8 @@ class RoomSendEventRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)" PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True) register_txn_path(self, PATTERNS, http_server, with_get=True)
@defer.inlineCallbacks async def on_POST(self, request, room_id, event_type, txn_id=None):
def on_POST(self, request, room_id, event_type, txn_id=None): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
event_dict = { event_dict = {
@ -246,7 +240,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
if b"ts" in request.args and requester.app_service: if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
event = yield self.event_creation_handler.create_and_send_nonmember_event( event = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id requester, event_dict, txn_id=txn_id
) )
@ -276,9 +270,8 @@ class JoinRoomAliasServlet(TransactionRestServlet):
PATTERNS = "/join/(?P<room_identifier>[^/]*)" PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks async def on_POST(self, request, room_identifier, txn_id=None):
def on_POST(self, request, room_identifier, txn_id=None): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
try: try:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -298,14 +291,14 @@ class JoinRoomAliasServlet(TransactionRestServlet):
elif RoomAlias.is_valid(room_identifier): elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier) room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
room_id = room_id.to_string() room_id = room_id.to_string()
else: else:
raise SynapseError( raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,) 400, "%s was not legal room ID or room alias" % (room_identifier,)
) )
yield self.room_member_handler.update_membership( await self.room_member_handler.update_membership(
requester=requester, requester=requester,
target=requester.user, target=requester.user,
room_id=room_id, room_id=room_id,
@ -335,12 +328,11 @@ class PublicRoomListRestServlet(TransactionRestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request):
server = parse_string(request, "server", default=None) server = parse_string(request, "server", default=None)
try: try:
yield self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
except InvalidClientCredentialsError as e: except InvalidClientCredentialsError as e:
# Option to allow servers to require auth when accessing # Option to allow servers to require auth when accessing
# /publicRooms via CS API. This is especially helpful in private # /publicRooms via CS API. This is especially helpful in private
@ -367,19 +359,18 @@ class PublicRoomListRestServlet(TransactionRestServlet):
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
if server: if server:
data = yield handler.get_remote_public_room_list( data = await handler.get_remote_public_room_list(
server, limit=limit, since_token=since_token server, limit=limit, since_token=since_token
) )
else: else:
data = yield handler.get_local_public_room_list( data = await handler.get_local_public_room_list(
limit=limit, since_token=since_token limit=limit, since_token=since_token
) )
return 200, data return 200, data
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.get_user_by_req(request, allow_guest=True)
server = parse_string(request, "server", default=None) server = parse_string(request, "server", default=None)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -408,7 +399,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
if server: if server:
data = yield handler.get_remote_public_room_list( data = await handler.get_remote_public_room_list(
server, server,
limit=limit, limit=limit,
since_token=since_token, since_token=since_token,
@ -417,7 +408,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
third_party_instance_id=third_party_instance_id, third_party_instance_id=third_party_instance_id,
) )
else: else:
data = yield handler.get_local_public_room_list( data = await handler.get_local_public_room_list(
limit=limit, limit=limit,
since_token=since_token, since_token=since_token,
search_filter=search_filter, search_filter=search_filter,
@ -436,10 +427,9 @@ class RoomMemberListRestServlet(RestServlet):
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, room_id):
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
handler = self.message_handler handler = self.message_handler
# request the state as of a given event, as identified by a stream token, # request the state as of a given event, as identified by a stream token,
@ -459,7 +449,7 @@ class RoomMemberListRestServlet(RestServlet):
membership = parse_string(request, "membership") membership = parse_string(request, "membership")
not_membership = parse_string(request, "not_membership") not_membership = parse_string(request, "not_membership")
events = yield handler.get_state_events( events = await handler.get_state_events(
room_id=room_id, room_id=room_id,
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
at_token=at_token, at_token=at_token,
@ -488,11 +478,10 @@ class JoinedRoomMemberListRestServlet(RestServlet):
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, room_id):
def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
users_with_profile = yield self.message_handler.get_joined_members( users_with_profile = await self.message_handler.get_joined_members(
requester, room_id requester, room_id
) )
@ -508,9 +497,8 @@ class RoomMessageListRestServlet(RestServlet):
self.pagination_handler = hs.get_pagination_handler() self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, room_id):
def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request, default_limit=10) pagination_config = PaginationConfig.from_request(request, default_limit=10)
as_client_event = b"raw" not in request.args as_client_event = b"raw" not in request.args
filter_bytes = parse_string(request, b"filter", encoding=None) filter_bytes = parse_string(request, b"filter", encoding=None)
@ -521,7 +509,7 @@ class RoomMessageListRestServlet(RestServlet):
as_client_event = False as_client_event = False
else: else:
event_filter = None event_filter = None
msgs = yield self.pagination_handler.get_messages( msgs = await self.pagination_handler.get_messages(
room_id=room_id, room_id=room_id,
requester=requester, requester=requester,
pagin_config=pagination_config, pagin_config=pagination_config,
@ -541,11 +529,10 @@ class RoomStateRestServlet(RestServlet):
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, room_id):
def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
# Get all the current state for this room # Get all the current state for this room
events = yield self.message_handler.get_state_events( events = await self.message_handler.get_state_events(
room_id=room_id, room_id=room_id,
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
is_guest=requester.is_guest, is_guest=requester.is_guest,
@ -562,11 +549,10 @@ class RoomInitialSyncRestServlet(RestServlet):
self.initial_sync_handler = hs.get_initial_sync_handler() self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, room_id):
def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
content = yield self.initial_sync_handler.room_initial_sync( content = await self.initial_sync_handler.room_initial_sync(
room_id=room_id, requester=requester, pagin_config=pagination_config room_id=room_id, requester=requester, pagin_config=pagination_config
) )
return 200, content return 200, content
@ -584,11 +570,10 @@ class RoomEventServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, room_id, event_id):
def on_GET(self, request, room_id, event_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
try: try:
event = yield self.event_handler.get_event( event = await self.event_handler.get_event(
requester.user, room_id, event_id requester.user, room_id, event_id
) )
except AuthError: except AuthError:
@ -599,7 +584,7 @@ class RoomEventServlet(RestServlet):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
if event: if event:
event = yield self._event_serializer.serialize_event(event, time_now) event = await self._event_serializer.serialize_event(event, time_now)
return 200, event return 200, event
return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
@ -617,9 +602,8 @@ class RoomEventContextServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, room_id, event_id):
def on_GET(self, request, room_id, event_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
limit = parse_integer(request, "limit", default=10) limit = parse_integer(request, "limit", default=10)
@ -631,7 +615,7 @@ class RoomEventContextServlet(RestServlet):
else: else:
event_filter = None event_filter = None
results = yield self.room_context_handler.get_event_context( results = await self.room_context_handler.get_event_context(
requester.user, room_id, event_id, limit, event_filter requester.user, room_id, event_id, limit, event_filter
) )
@ -639,16 +623,16 @@ class RoomEventContextServlet(RestServlet):
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
results["events_before"] = yield self._event_serializer.serialize_events( results["events_before"] = await self._event_serializer.serialize_events(
results["events_before"], time_now results["events_before"], time_now
) )
results["event"] = yield self._event_serializer.serialize_event( results["event"] = await self._event_serializer.serialize_event(
results["event"], time_now results["event"], time_now
) )
results["events_after"] = yield self._event_serializer.serialize_events( results["events_after"] = await self._event_serializer.serialize_events(
results["events_after"], time_now results["events_after"], time_now
) )
results["state"] = yield self._event_serializer.serialize_events( results["state"] = await self._event_serializer.serialize_events(
results["state"], time_now results["state"], time_now
) )
@ -665,11 +649,10 @@ class RoomForgetRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget" PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks async def on_POST(self, request, room_id, txn_id=None):
def on_POST(self, request, room_id, txn_id=None): requester = await self.auth.get_user_by_req(request, allow_guest=False)
requester = yield self.auth.get_user_by_req(request, allow_guest=False)
yield self.room_member_handler.forget(user=requester.user, room_id=room_id) await self.room_member_handler.forget(user=requester.user, room_id=room_id)
return 200, {} return 200, {}
@ -696,9 +679,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
) )
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks async def on_POST(self, request, room_id, membership_action, txn_id=None):
def on_POST(self, request, room_id, membership_action, txn_id=None): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
if requester.is_guest and membership_action not in { if requester.is_guest and membership_action not in {
Membership.JOIN, Membership.JOIN,
@ -714,7 +696,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content = {} content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content): if membership_action == "invite" and self._has_3pid_invite_keys(content):
yield self.room_member_handler.do_3pid_invite( await self.room_member_handler.do_3pid_invite(
room_id, room_id,
requester.user, requester.user,
content["medium"], content["medium"],
@ -735,7 +717,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
if "reason" in content and membership_action in ["kick", "ban"]: if "reason" in content and membership_action in ["kick", "ban"]:
event_content = {"reason": content["reason"]} event_content = {"reason": content["reason"]}
yield self.room_member_handler.update_membership( await self.room_member_handler.update_membership(
requester=requester, requester=requester,
target=target, target=target,
room_id=room_id, room_id=room_id,
@ -777,12 +759,11 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks async def on_POST(self, request, room_id, event_id, txn_id=None):
def on_POST(self, request, room_id, event_id, txn_id=None): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
event = yield self.event_creation_handler.create_and_send_nonmember_event( event = await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Redaction, "type": EventTypes.Redaction,
@ -816,29 +797,28 @@ class RoomTypingRestServlet(RestServlet):
self.typing_handler = hs.get_typing_handler() self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_PUT(self, request, room_id, user_id):
def on_PUT(self, request, room_id, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
room_id = urlparse.unquote(room_id) room_id = urlparse.unquote(room_id)
target_user = UserID.from_string(urlparse.unquote(user_id)) target_user = UserID.from_string(urlparse.unquote(user_id))
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
yield self.presence_handler.bump_presence_active_time(requester.user) await self.presence_handler.bump_presence_active_time(requester.user)
# Limit timeout to stop people from setting silly typing timeouts. # Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000) timeout = min(content.get("timeout", 30000), 120000)
if content["typing"]: if content["typing"]:
yield self.typing_handler.started_typing( await self.typing_handler.started_typing(
target_user=target_user, target_user=target_user,
auth_user=requester.user, auth_user=requester.user,
room_id=room_id, room_id=room_id,
timeout=timeout, timeout=timeout,
) )
else: else:
yield self.typing_handler.stopped_typing( await self.typing_handler.stopped_typing(
target_user=target_user, auth_user=requester.user, room_id=room_id target_user=target_user, auth_user=requester.user, room_id=room_id
) )
@ -853,14 +833,13 @@ class SearchRestServlet(RestServlet):
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
batch = parse_string(request, "next_batch") batch = parse_string(request, "next_batch")
results = yield self.handlers.search_handler.search( results = await self.handlers.search_handler.search(
requester.user, content, batch requester.user, content, batch
) )
@ -875,11 +854,10 @@ class JoinedRoomsRestServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
room_ids = yield self.store.get_rooms_for_user(requester.user.to_string()) room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
return 200, {"joined_rooms": list(room_ids)} return 200, {"joined_rooms": list(room_ids)}