mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 00:23:51 +01:00
Port room rest handlers to async/await
This commit is contained in:
parent
561133c3c5
commit
9be41bc121
1 changed files with 72 additions and 94 deletions
|
@ -21,8 +21,6 @@ from six.moves.urllib import parse as urlparse
|
|||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
|
@ -85,11 +83,10 @@ class RoomCreateRestServlet(TransactionRestServlet):
|
|||
set_tag("txn_id", txn_id)
|
||||
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
async def on_POST(self, request):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
info = yield self._room_creation_handler.create_room(
|
||||
info = await self._room_creation_handler.create_room(
|
||||
requester, self.get_room_config(request)
|
||||
)
|
||||
|
||||
|
@ -154,15 +151,14 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
def on_PUT_no_state_key(self, request, room_id, event_type):
|
||||
return self.on_PUT(request, room_id, event_type, "")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, event_type, state_key):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
async def on_GET(self, request, room_id, event_type, state_key):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
format = parse_string(
|
||||
request, "format", default="content", allowed_values=["content", "event"]
|
||||
)
|
||||
|
||||
msg_handler = self.message_handler
|
||||
data = yield msg_handler.get_room_data(
|
||||
data = await msg_handler.get_room_data(
|
||||
user_id=requester.user.to_string(),
|
||||
room_id=room_id,
|
||||
event_type=event_type,
|
||||
|
@ -179,9 +175,8 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
elif format == "content":
|
||||
return 200, data.get_dict()["content"]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
if txn_id:
|
||||
set_tag("txn_id", txn_id)
|
||||
|
@ -200,7 +195,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
|
||||
if event_type == EventTypes.Member:
|
||||
membership = content.get("membership", None)
|
||||
event = yield self.room_member_handler.update_membership(
|
||||
event = await self.room_member_handler.update_membership(
|
||||
requester,
|
||||
target=UserID.from_string(state_key),
|
||||
room_id=room_id,
|
||||
|
@ -208,7 +203,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
content=content,
|
||||
)
|
||||
else:
|
||||
event = yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
event = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict, txn_id=txn_id
|
||||
)
|
||||
|
||||
|
@ -231,9 +226,8 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server, with_get=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, event_type, txn_id=None):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
async def on_POST(self, request, room_id, event_type, txn_id=None):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
event_dict = {
|
||||
|
@ -246,7 +240,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
if b"ts" in request.args and requester.app_service:
|
||||
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
|
||||
|
||||
event = yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
event = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict, txn_id=txn_id
|
||||
)
|
||||
|
||||
|
@ -276,9 +270,8 @@ class JoinRoomAliasServlet(TransactionRestServlet):
|
|||
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_identifier, txn_id=None):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
async def on_POST(self, request, room_identifier, txn_id=None):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
try:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
@ -298,14 +291,14 @@ class JoinRoomAliasServlet(TransactionRestServlet):
|
|||
elif RoomAlias.is_valid(room_identifier):
|
||||
handler = self.room_member_handler
|
||||
room_alias = RoomAlias.from_string(room_identifier)
|
||||
room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
|
||||
room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
|
||||
room_id = room_id.to_string()
|
||||
else:
|
||||
raise SynapseError(
|
||||
400, "%s was not legal room ID or room alias" % (room_identifier,)
|
||||
)
|
||||
|
||||
yield self.room_member_handler.update_membership(
|
||||
await self.room_member_handler.update_membership(
|
||||
requester=requester,
|
||||
target=requester.user,
|
||||
room_id=room_id,
|
||||
|
@ -335,12 +328,11 @@ class PublicRoomListRestServlet(TransactionRestServlet):
|
|||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
async def on_GET(self, request):
|
||||
server = parse_string(request, "server", default=None)
|
||||
|
||||
try:
|
||||
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
except InvalidClientCredentialsError as e:
|
||||
# Option to allow servers to require auth when accessing
|
||||
# /publicRooms via CS API. This is especially helpful in private
|
||||
|
@ -367,19 +359,18 @@ class PublicRoomListRestServlet(TransactionRestServlet):
|
|||
|
||||
handler = self.hs.get_room_list_handler()
|
||||
if server:
|
||||
data = yield handler.get_remote_public_room_list(
|
||||
data = await handler.get_remote_public_room_list(
|
||||
server, limit=limit, since_token=since_token
|
||||
)
|
||||
else:
|
||||
data = yield handler.get_local_public_room_list(
|
||||
data = await handler.get_local_public_room_list(
|
||||
limit=limit, since_token=since_token
|
||||
)
|
||||
|
||||
return 200, data
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
async def on_POST(self, request):
|
||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
server = parse_string(request, "server", default=None)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
@ -408,7 +399,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
|
|||
|
||||
handler = self.hs.get_room_list_handler()
|
||||
if server:
|
||||
data = yield handler.get_remote_public_room_list(
|
||||
data = await handler.get_remote_public_room_list(
|
||||
server,
|
||||
limit=limit,
|
||||
since_token=since_token,
|
||||
|
@ -417,7 +408,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
|
|||
third_party_instance_id=third_party_instance_id,
|
||||
)
|
||||
else:
|
||||
data = yield handler.get_local_public_room_list(
|
||||
data = await handler.get_local_public_room_list(
|
||||
limit=limit,
|
||||
since_token=since_token,
|
||||
search_filter=search_filter,
|
||||
|
@ -436,10 +427,9 @@ class RoomMemberListRestServlet(RestServlet):
|
|||
self.message_handler = hs.get_message_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
async def on_GET(self, request, room_id):
|
||||
# TODO support Pagination stream API (limit/tokens)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
handler = self.message_handler
|
||||
|
||||
# request the state as of a given event, as identified by a stream token,
|
||||
|
@ -459,7 +449,7 @@ class RoomMemberListRestServlet(RestServlet):
|
|||
membership = parse_string(request, "membership")
|
||||
not_membership = parse_string(request, "not_membership")
|
||||
|
||||
events = yield handler.get_state_events(
|
||||
events = await handler.get_state_events(
|
||||
room_id=room_id,
|
||||
user_id=requester.user.to_string(),
|
||||
at_token=at_token,
|
||||
|
@ -488,11 +478,10 @@ class JoinedRoomMemberListRestServlet(RestServlet):
|
|||
self.message_handler = hs.get_message_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
async def on_GET(self, request, room_id):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
users_with_profile = yield self.message_handler.get_joined_members(
|
||||
users_with_profile = await self.message_handler.get_joined_members(
|
||||
requester, room_id
|
||||
)
|
||||
|
||||
|
@ -508,9 +497,8 @@ class RoomMessageListRestServlet(RestServlet):
|
|||
self.pagination_handler = hs.get_pagination_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
async def on_GET(self, request, room_id):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
pagination_config = PaginationConfig.from_request(request, default_limit=10)
|
||||
as_client_event = b"raw" not in request.args
|
||||
filter_bytes = parse_string(request, b"filter", encoding=None)
|
||||
|
@ -521,7 +509,7 @@ class RoomMessageListRestServlet(RestServlet):
|
|||
as_client_event = False
|
||||
else:
|
||||
event_filter = None
|
||||
msgs = yield self.pagination_handler.get_messages(
|
||||
msgs = await self.pagination_handler.get_messages(
|
||||
room_id=room_id,
|
||||
requester=requester,
|
||||
pagin_config=pagination_config,
|
||||
|
@ -541,11 +529,10 @@ class RoomStateRestServlet(RestServlet):
|
|||
self.message_handler = hs.get_message_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
async def on_GET(self, request, room_id):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
# Get all the current state for this room
|
||||
events = yield self.message_handler.get_state_events(
|
||||
events = await self.message_handler.get_state_events(
|
||||
room_id=room_id,
|
||||
user_id=requester.user.to_string(),
|
||||
is_guest=requester.is_guest,
|
||||
|
@ -562,11 +549,10 @@ class RoomInitialSyncRestServlet(RestServlet):
|
|||
self.initial_sync_handler = hs.get_initial_sync_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
async def on_GET(self, request, room_id):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
pagination_config = PaginationConfig.from_request(request)
|
||||
content = yield self.initial_sync_handler.room_initial_sync(
|
||||
content = await self.initial_sync_handler.room_initial_sync(
|
||||
room_id=room_id, requester=requester, pagin_config=pagination_config
|
||||
)
|
||||
return 200, content
|
||||
|
@ -584,11 +570,10 @@ class RoomEventServlet(RestServlet):
|
|||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, event_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
async def on_GET(self, request, room_id, event_id):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
try:
|
||||
event = yield self.event_handler.get_event(
|
||||
event = await self.event_handler.get_event(
|
||||
requester.user, room_id, event_id
|
||||
)
|
||||
except AuthError:
|
||||
|
@ -599,7 +584,7 @@ class RoomEventServlet(RestServlet):
|
|||
|
||||
time_now = self.clock.time_msec()
|
||||
if event:
|
||||
event = yield self._event_serializer.serialize_event(event, time_now)
|
||||
event = await self._event_serializer.serialize_event(event, time_now)
|
||||
return 200, event
|
||||
|
||||
return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
|
||||
|
@ -617,9 +602,8 @@ class RoomEventContextServlet(RestServlet):
|
|||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, event_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
async def on_GET(self, request, room_id, event_id):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
limit = parse_integer(request, "limit", default=10)
|
||||
|
||||
|
@ -631,7 +615,7 @@ class RoomEventContextServlet(RestServlet):
|
|||
else:
|
||||
event_filter = None
|
||||
|
||||
results = yield self.room_context_handler.get_event_context(
|
||||
results = await self.room_context_handler.get_event_context(
|
||||
requester.user, room_id, event_id, limit, event_filter
|
||||
)
|
||||
|
||||
|
@ -639,16 +623,16 @@ class RoomEventContextServlet(RestServlet):
|
|||
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
results["events_before"] = yield self._event_serializer.serialize_events(
|
||||
results["events_before"] = await self._event_serializer.serialize_events(
|
||||
results["events_before"], time_now
|
||||
)
|
||||
results["event"] = yield self._event_serializer.serialize_event(
|
||||
results["event"] = await self._event_serializer.serialize_event(
|
||||
results["event"], time_now
|
||||
)
|
||||
results["events_after"] = yield self._event_serializer.serialize_events(
|
||||
results["events_after"] = await self._event_serializer.serialize_events(
|
||||
results["events_after"], time_now
|
||||
)
|
||||
results["state"] = yield self._event_serializer.serialize_events(
|
||||
results["state"] = await self._event_serializer.serialize_events(
|
||||
results["state"], time_now
|
||||
)
|
||||
|
||||
|
@ -665,11 +649,10 @@ class RoomForgetRestServlet(TransactionRestServlet):
|
|||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, txn_id=None):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=False)
|
||||
async def on_POST(self, request, room_id, txn_id=None):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||
|
||||
yield self.room_member_handler.forget(user=requester.user, room_id=room_id)
|
||||
await self.room_member_handler.forget(user=requester.user, room_id=room_id)
|
||||
|
||||
return 200, {}
|
||||
|
||||
|
@ -696,9 +679,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
|||
)
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, membership_action, txn_id=None):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
async def on_POST(self, request, room_id, membership_action, txn_id=None):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
if requester.is_guest and membership_action not in {
|
||||
Membership.JOIN,
|
||||
|
@ -714,7 +696,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
|||
content = {}
|
||||
|
||||
if membership_action == "invite" and self._has_3pid_invite_keys(content):
|
||||
yield self.room_member_handler.do_3pid_invite(
|
||||
await self.room_member_handler.do_3pid_invite(
|
||||
room_id,
|
||||
requester.user,
|
||||
content["medium"],
|
||||
|
@ -735,7 +717,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
|||
if "reason" in content and membership_action in ["kick", "ban"]:
|
||||
event_content = {"reason": content["reason"]}
|
||||
|
||||
yield self.room_member_handler.update_membership(
|
||||
await self.room_member_handler.update_membership(
|
||||
requester=requester,
|
||||
target=target,
|
||||
room_id=room_id,
|
||||
|
@ -777,12 +759,11 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
|
|||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, event_id, txn_id=None):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
async def on_POST(self, request, room_id, event_id, txn_id=None):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
event = yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
event = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Redaction,
|
||||
|
@ -816,29 +797,28 @@ class RoomTypingRestServlet(RestServlet):
|
|||
self.typing_handler = hs.get_typing_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, user_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
async def on_PUT(self, request, room_id, user_id):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
room_id = urlparse.unquote(room_id)
|
||||
target_user = UserID.from_string(urlparse.unquote(user_id))
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
yield self.presence_handler.bump_presence_active_time(requester.user)
|
||||
await self.presence_handler.bump_presence_active_time(requester.user)
|
||||
|
||||
# Limit timeout to stop people from setting silly typing timeouts.
|
||||
timeout = min(content.get("timeout", 30000), 120000)
|
||||
|
||||
if content["typing"]:
|
||||
yield self.typing_handler.started_typing(
|
||||
await self.typing_handler.started_typing(
|
||||
target_user=target_user,
|
||||
auth_user=requester.user,
|
||||
room_id=room_id,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
yield self.typing_handler.stopped_typing(
|
||||
await self.typing_handler.stopped_typing(
|
||||
target_user=target_user, auth_user=requester.user, room_id=room_id
|
||||
)
|
||||
|
||||
|
@ -853,14 +833,13 @@ class SearchRestServlet(RestServlet):
|
|||
self.handlers = hs.get_handlers()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
async def on_POST(self, request):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
batch = parse_string(request, "next_batch")
|
||||
results = yield self.handlers.search_handler.search(
|
||||
results = await self.handlers.search_handler.search(
|
||||
requester.user, content, batch
|
||||
)
|
||||
|
||||
|
@ -875,11 +854,10 @@ class JoinedRoomsRestServlet(RestServlet):
|
|||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
async def on_GET(self, request):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
room_ids = yield self.store.get_rooms_for_user(requester.user.to_string())
|
||||
room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
|
||||
return 200, {"joined_rooms": list(room_ids)}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue