From 1a0997bbd5f9134b21b1105c6b14e09480f193a9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 5 Dec 2019 15:53:06 +0000 Subject: [PATCH 1/3] Port rest/v1 to async/await --- synapse/app/synchrotron.py | 2 +- synapse/rest/client/v1/directory.py | 53 ++++++++++------------- synapse/rest/client/v1/events.py | 18 +++----- synapse/rest/client/v1/initial_sync.py | 8 ++-- synapse/rest/client/v1/login.py | 60 +++++++++++--------------- synapse/rest/client/v1/logout.py | 20 ++++----- synapse/rest/client/v1/presence.py | 18 +++----- synapse/rest/client/v1/profile.py | 48 +++++++++------------ synapse/rest/client/v1/push_rule.py | 24 +++++------ synapse/rest/client/v1/pusher.py | 27 +++++------- synapse/rest/client/v1/voip.py | 7 +-- 11 files changed, 118 insertions(+), 167 deletions(-) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index b14da09f47..288ee64b42 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -151,7 +151,7 @@ class SynchrotronPresence(object): def set_state(self, user, state, ignore_status_msg=False): # TODO Hows this supposed to work? - pass + return defer.succeed(None) get_states = __func__(PresenceHandler.get_states) get_state = __func__(PresenceHandler.get_state) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 4ea3666874..5934b1fe8b 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import ( AuthError, Codes, @@ -47,17 +45,15 @@ class ClientDirectoryServer(RestServlet): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_alias): + async def on_GET(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) dir_handler = self.handlers.directory_handler - res = yield dir_handler.get_association(room_alias) + res = await dir_handler.get_association(room_alias) return 200, res - @defer.inlineCallbacks - def on_PUT(self, request, room_alias): + async def on_PUT(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) content = parse_json_object_from_request(request) @@ -77,26 +73,25 @@ class ClientDirectoryServer(RestServlet): # TODO(erikj): Check types. - room = yield self.store.get_room(room_id) + room = await self.store.get_room(room_id) if room is None: raise SynapseError(400, "Room does not exist") - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) - yield self.handlers.directory_handler.create_association( + await self.handlers.directory_handler.create_association( requester, room_alias, room_id, servers ) return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, room_alias): + async def on_DELETE(self, request, room_alias): dir_handler = self.handlers.directory_handler try: - service = yield self.auth.get_appservice_by_req(request) + service = await self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_appservice_association(service, room_alias) + await dir_handler.delete_appservice_association(service, room_alias) logger.info( "Application service at %s deleted alias %s", service.url, @@ -107,12 +102,12 @@ class ClientDirectoryServer(RestServlet): # fallback to default user behaviour if they aren't an AS pass - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user = requester.user room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_association(requester, room_alias) + await dir_handler.delete_association(requester, room_alias) logger.info( "User %s deleted alias %s", user.to_string(), room_alias.to_string() @@ -130,32 +125,29 @@ class ClientDirectoryListServer(RestServlet): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - room = yield self.store.get_room(room_id) + async def on_GET(self, request, room_id): + room = await self.store.get_room(room_id) if room is None: raise NotFoundError("Unknown room") return 200, {"visibility": "public" if room["is_public"] else "private"} - @defer.inlineCallbacks - def on_PUT(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) visibility = content.get("visibility", "public") - yield self.handlers.directory_handler.edit_published_room_list( + await self.handlers.directory_handler.edit_published_room_list( requester, room_id, visibility ) return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - yield self.handlers.directory_handler.edit_published_room_list( + await self.handlers.directory_handler.edit_published_room_list( requester, room_id, "private" ) @@ -181,15 +173,14 @@ class ClientAppserviceDirectoryListServer(RestServlet): def on_DELETE(self, request, network_id, room_id): return self._edit(request, network_id, room_id, "private") - @defer.inlineCallbacks - def _edit(self, request, network_id, room_id, visibility): - requester = yield self.auth.get_user_by_req(request) + async def _edit(self, request, network_id, room_id, visibility): + requester = await self.auth.get_user_by_req(request) if not requester.app_service: raise AuthError( 403, "Only appservices can edit the appservice published room list" ) - yield self.handlers.directory_handler.edit_published_appservice_room_list( + await self.handlers.directory_handler.edit_published_appservice_room_list( requester.app_service.id, network_id, room_id, visibility ) diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 6651b4cf07..4beb617733 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -16,8 +16,6 @@ """This module contains REST servlets to do with event streaming, /events.""" import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet from synapse.rest.client.v2_alpha._base import client_patterns @@ -36,9 +34,8 @@ class EventStreamRestServlet(RestServlet): self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) is_guest = requester.is_guest room_id = None if is_guest: @@ -57,7 +54,7 @@ class EventStreamRestServlet(RestServlet): as_client_event = b"raw" not in request.args - chunk = yield self.event_stream_handler.get_stream( + chunk = await self.event_stream_handler.get_stream( requester.user.to_string(), pagin_config, timeout=timeout, @@ -83,14 +80,13 @@ class EventRestServlet(RestServlet): self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks - def on_GET(self, request, event_id): - requester = yield self.auth.get_user_by_req(request) - event = yield self.event_handler.get_event(requester.user, None, event_id) + async def on_GET(self, request, event_id): + requester = await self.auth.get_user_by_req(request) + event = await self.event_handler.get_event(requester.user, None, event_id) time_now = self.clock.time_msec() if event: - event = yield self._event_serializer.serialize_event(event, time_now) + event = await self._event_serializer.serialize_event(event, time_now) return 200, event else: return 404, "Event not found." diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 2da3cd7511..910b3b4eeb 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_boolean from synapse.rest.client.v2_alpha._base import client_patterns @@ -29,13 +28,12 @@ class InitialSyncRestServlet(RestServlet): self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) as_client_event = b"raw" not in request.args pagination_config = PaginationConfig.from_request(request) include_archived = parse_boolean(request, "archived", default=False) - content = yield self.initial_sync_handler.snapshot_all_rooms( + content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), pagin_config=pagination_config, as_client_event=as_client_event, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 19eb15003d..ff9c978fe7 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -18,7 +18,6 @@ import xml.etree.ElementTree as ET from six.moves import urllib -from twisted.internet import defer from twisted.web.client import PartialDownloadError from synapse.api.errors import Codes, LoginError, SynapseError @@ -130,8 +129,7 @@ class LoginRestServlet(RestServlet): def on_OPTIONS(self, request): return 200, {} - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): self._address_ratelimiter.ratelimit( request.getClientIP(), time_now_s=self.hs.clock.time(), @@ -145,11 +143,11 @@ class LoginRestServlet(RestServlet): if self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE ): - result = yield self.do_jwt_login(login_submission) + result = await self.do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - result = yield self.do_token_login(login_submission) + result = await self.do_token_login(login_submission) else: - result = yield self._do_other_login(login_submission) + result = await self._do_other_login(login_submission) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -158,8 +156,7 @@ class LoginRestServlet(RestServlet): result["well_known"] = well_known_data return 200, result - @defer.inlineCallbacks - def _do_other_login(self, login_submission): + async def _do_other_login(self, login_submission): """Handle non-token/saml/jwt logins Args: @@ -219,20 +216,20 @@ class LoginRestServlet(RestServlet): ( canonical_user_id, callback_3pid, - ) = yield self.auth_handler.check_password_provider_3pid( + ) = await self.auth_handler.check_password_provider_3pid( medium, address, login_submission["password"] ) if canonical_user_id: # Authentication through password provider and 3pid succeeded - result = yield self._complete_login( + result = await self._complete_login( canonical_user_id, login_submission, callback_3pid ) return result # No password providers were able to handle this 3pid # Check local store - user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastore().get_user_id_by_threepid( medium, address ) if not user_id: @@ -280,7 +277,7 @@ class LoginRestServlet(RestServlet): ) try: - canonical_user_id, callback = yield self.auth_handler.validate_login( + canonical_user_id, callback = await self.auth_handler.validate_login( identifier["user"], login_submission ) except LoginError: @@ -297,13 +294,12 @@ class LoginRestServlet(RestServlet): ) raise - result = yield self._complete_login( + result = await self._complete_login( canonical_user_id, login_submission, callback ) return result - @defer.inlineCallbacks - def _complete_login( + async def _complete_login( self, user_id, login_submission, callback=None, create_non_existant_users=False ): """Called when we've successfully authed the user and now need to @@ -337,15 +333,15 @@ class LoginRestServlet(RestServlet): ) if create_non_existant_users: - user_id = yield self.auth_handler.check_user_exists(user_id) + user_id = await self.auth_handler.check_user_exists(user_id) if not user_id: - user_id = yield self.registration_handler.register_user( + user_id = await self.registration_handler.register_user( localpart=UserID.from_string(user_id).localpart ) device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name ) @@ -357,23 +353,21 @@ class LoginRestServlet(RestServlet): } if callback is not None: - yield callback(result) + await callback(result) return result - @defer.inlineCallbacks - def do_token_login(self, login_submission): + async def do_token_login(self, login_submission): token = login_submission["token"] auth_handler = self.auth_handler - user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id( + user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( token ) - result = yield self._complete_login(user_id, login_submission) + result = await self._complete_login(user_id, login_submission) return result - @defer.inlineCallbacks - def do_jwt_login(self, login_submission): + async def do_jwt_login(self, login_submission): token = login_submission.get("token", None) if token is None: raise LoginError( @@ -397,7 +391,7 @@ class LoginRestServlet(RestServlet): raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user_id = UserID(user, self.hs.hostname).to_string() - result = yield self._complete_login( + result = await self._complete_login( user_id, login_submission, create_non_existant_users=True ) return result @@ -460,8 +454,7 @@ class CasTicketServlet(RestServlet): self._sso_auth_handler = SSOAuthHandler(hs) self._http_client = hs.get_proxied_http_client() - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): client_redirect_url = parse_string(request, "redirectUrl", required=True) uri = self.cas_server_url + "/proxyValidate" args = { @@ -469,12 +462,12 @@ class CasTicketServlet(RestServlet): "service": self.cas_service_url, } try: - body = yield self._http_client.get_raw(uri, args) + body = await self._http_client.get_raw(uri, args) except PartialDownloadError as pde: # Twisted raises this error if the connection is closed, # even if that's being used old-http style to signal end-of-data body = pde.response - result = yield self.handle_cas_response(request, body, client_redirect_url) + result = await self.handle_cas_response(request, body, client_redirect_url) return result def handle_cas_response(self, request, cas_response_body, client_redirect_url): @@ -555,8 +548,7 @@ class SSOAuthHandler(object): self._registration_handler = hs.get_registration_handler() self._macaroon_gen = hs.get_macaroon_generator() - @defer.inlineCallbacks - def on_successful_auth( + async def on_successful_auth( self, username, request, client_redirect_url, user_display_name=None ): """Called once the user has successfully authenticated with the SSO. @@ -582,9 +574,9 @@ class SSOAuthHandler(object): """ localpart = map_username_to_mxid_localpart(username) user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = yield self._auth_handler.check_user_exists(user_id) + registered_user_id = await self._auth_handler.check_user_exists(user_id) if not registered_user_id: - registered_user_id = yield self._registration_handler.register_user( + registered_user_id = await self._registration_handler.register_user( localpart=localpart, default_display_name=user_display_name ) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 4785a34d75..1cf3caf832 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet from synapse.rest.client.v2_alpha._base import client_patterns @@ -35,17 +33,16 @@ class LogoutRestServlet(RestServlet): def on_OPTIONS(self, request): return 200, {} - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) if requester.device_id is None: # the acccess token wasn't associated with a device. # Just delete the access token access_token = self.auth.get_access_token_from_request(request) - yield self._auth_handler.delete_access_token(access_token) + await self._auth_handler.delete_access_token(access_token) else: - yield self._device_handler.delete_device( + await self._device_handler.delete_device( requester.user.to_string(), requester.device_id ) @@ -64,17 +61,16 @@ class LogoutAllRestServlet(RestServlet): def on_OPTIONS(self, request): return 200, {} - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() # first delete all of the user's devices - yield self._device_handler.delete_all_devices_for_user(user_id) + await self._device_handler.delete_all_devices_for_user(user_id) # .. and then delete any access tokens which weren't associated with # devices. - yield self._auth_handler.delete_access_tokens_for_user(user_id) + await self._auth_handler.delete_access_tokens_for_user(user_id) return 200, {} diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 0153525cef..eec16f8ad8 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -19,8 +19,6 @@ import logging from six import string_types -from twisted.internet import defer - from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -40,27 +38,25 @@ class PresenceStatusRestServlet(RestServlet): self.clock = hs.get_clock() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id): + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if requester.user != user: - allowed = yield self.presence_handler.is_visible( + allowed = await self.presence_handler.is_visible( observed_user=user, observer_user=requester.user ) if not allowed: raise AuthError(403, "You are not allowed to see their presence.") - state = yield self.presence_handler.get_state(target_user=user) + state = await self.presence_handler.get_state(target_user=user) state = format_user_presence_state(state, self.clock.time_msec()) return 200, state - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if requester.user != user: @@ -86,7 +82,7 @@ class PresenceStatusRestServlet(RestServlet): raise SynapseError(400, "Unable to parse state") if self.hs.config.use_presence: - yield self.presence_handler.set_state(user, state) + await self.presence_handler.set_state(user, state) return 200, {} diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index bbce2e2b71..1eac8a44c5 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -14,7 +14,6 @@ # limitations under the License. """ This module contains REST servlets to do with profile: /profile/ """ -from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.rest.client.v2_alpha._base import client_patterns @@ -30,19 +29,18 @@ class ProfileDisplaynameRestServlet(RestServlet): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): requester_user = None if self.hs.config.require_auth_for_profile_requests: - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) requester_user = requester.user user = UserID.from_string(user_id) - yield self.profile_handler.check_profile_query_allowed(user, requester_user) + await self.profile_handler.check_profile_query_allowed(user, requester_user) - displayname = yield self.profile_handler.get_displayname(user) + displayname = await self.profile_handler.get_displayname(user) ret = {} if displayname is not None: @@ -50,11 +48,10 @@ class ProfileDisplaynameRestServlet(RestServlet): return 200, ret - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) - is_admin = yield self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester.user) content = parse_json_object_from_request(request) @@ -63,7 +60,7 @@ class ProfileDisplaynameRestServlet(RestServlet): except Exception: return 400, "Unable to parse name" - yield self.profile_handler.set_displayname(user, requester, new_name, is_admin) + await self.profile_handler.set_displayname(user, requester, new_name, is_admin) return 200, {} @@ -80,19 +77,18 @@ class ProfileAvatarURLRestServlet(RestServlet): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): requester_user = None if self.hs.config.require_auth_for_profile_requests: - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) requester_user = requester.user user = UserID.from_string(user_id) - yield self.profile_handler.check_profile_query_allowed(user, requester_user) + await self.profile_handler.check_profile_query_allowed(user, requester_user) - avatar_url = yield self.profile_handler.get_avatar_url(user) + avatar_url = await self.profile_handler.get_avatar_url(user) ret = {} if avatar_url is not None: @@ -100,11 +96,10 @@ class ProfileAvatarURLRestServlet(RestServlet): return 200, ret - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) - is_admin = yield self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester.user) content = parse_json_object_from_request(request) try: @@ -112,7 +107,7 @@ class ProfileAvatarURLRestServlet(RestServlet): except Exception: return 400, "Unable to parse name" - yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin) + await self.profile_handler.set_avatar_url(user, requester, new_name, is_admin) return 200, {} @@ -129,20 +124,19 @@ class ProfileRestServlet(RestServlet): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): requester_user = None if self.hs.config.require_auth_for_profile_requests: - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) requester_user = requester.user user = UserID.from_string(user_id) - yield self.profile_handler.check_profile_query_allowed(user, requester_user) + await self.profile_handler.check_profile_query_allowed(user, requester_user) - displayname = yield self.profile_handler.get_displayname(user) - avatar_url = yield self.profile_handler.get_avatar_url(user) + displayname = await self.profile_handler.get_displayname(user) + avatar_url = await self.profile_handler.get_avatar_url(user) ret = {} if displayname is not None: diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 9f8c3d09e3..4f74600239 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.api.errors import ( NotFoundError, @@ -46,8 +45,7 @@ class PushRuleRestServlet(RestServlet): self.notifier = hs.get_notifier() self._is_worker = hs.config.worker_app is not None - @defer.inlineCallbacks - def on_PUT(self, request, path): + async def on_PUT(self, request, path): if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") @@ -57,7 +55,7 @@ class PushRuleRestServlet(RestServlet): except InvalidRuleException as e: raise SynapseError(400, str(e)) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: raise SynapseError(400, "rule_id may not contain slashes") @@ -67,7 +65,7 @@ class PushRuleRestServlet(RestServlet): user_id = requester.user.to_string() if "attr" in spec: - yield self.set_rule_attr(user_id, spec, content) + await self.set_rule_attr(user_id, spec, content) self.notify_user(user_id) return 200, {} @@ -91,7 +89,7 @@ class PushRuleRestServlet(RestServlet): after = _namespaced_rule_id(spec, after) try: - yield self.store.add_push_rule( + await self.store.add_push_rule( user_id=user_id, rule_id=_namespaced_rule_id_from_spec(spec), priority_class=priority_class, @@ -108,20 +106,19 @@ class PushRuleRestServlet(RestServlet): return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, path): + async def on_DELETE(self, request, path): if self._is_worker: raise Exception("Cannot handle DELETE /push_rules on worker") spec = _rule_spec_from_path([x for x in path.split("/")]) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() namespaced_rule_id = _namespaced_rule_id_from_spec(spec) try: - yield self.store.delete_push_rule(user_id, namespaced_rule_id) + await self.store.delete_push_rule(user_id, namespaced_rule_id) self.notify_user(user_id) return 200, {} except StoreError as e: @@ -130,15 +127,14 @@ class PushRuleRestServlet(RestServlet): else: raise - @defer.inlineCallbacks - def on_GET(self, request, path): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, path): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rules = yield self.store.get_push_rules_for_user(user_id) + rules = await self.store.get_push_rules_for_user(user_id) rules = format_push_rules_for_user(requester.user, rules) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 41660682d9..0791866f55 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import Codes, StoreError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import ( @@ -39,12 +37,11 @@ class PushersRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) user = requester.user - pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) + pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) allowed_keys = [ "app_display_name", @@ -78,9 +75,8 @@ class PushersSetRestServlet(RestServlet): self.notifier = hs.get_notifier() self.pusher_pool = self.hs.get_pusherpool() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user = requester.user content = parse_json_object_from_request(request) @@ -91,7 +87,7 @@ class PushersSetRestServlet(RestServlet): and "kind" in content and content["kind"] is None ): - yield self.pusher_pool.remove_pusher( + await self.pusher_pool.remove_pusher( content["app_id"], content["pushkey"], user_id=user.to_string() ) return 200, {} @@ -117,14 +113,14 @@ class PushersSetRestServlet(RestServlet): append = content["append"] if not append: - yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( + await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( app_id=content["app_id"], pushkey=content["pushkey"], not_user_id=user.to_string(), ) try: - yield self.pusher_pool.add_pusher( + await self.pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, kind=content["kind"], @@ -164,16 +160,15 @@ class PushersRemoveRestServlet(RestServlet): self.auth = hs.get_auth() self.pusher_pool = self.hs.get_pusherpool() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, rights="delete_pusher") + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, rights="delete_pusher") user = requester.user app_id = parse_string(request, "app_id", required=True) pushkey = parse_string(request, "pushkey", required=True) try: - yield self.pusher_pool.remove_pusher( + await self.pusher_pool.remove_pusher( app_id=app_id, pushkey=pushkey, user_id=user.to_string() ) except StoreError as se: diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 2afdbb89e5..747d46eac2 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -17,8 +17,6 @@ import base64 import hashlib import hmac -from twisted.internet import defer - from synapse.http.servlet import RestServlet from synapse.rest.client.v2_alpha._base import client_patterns @@ -31,9 +29,8 @@ class VoipRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req( + async def on_GET(self, request): + requester = await self.auth.get_user_by_req( request, self.hs.config.turn_allow_guests ) From 4ca3ef10b9a8d15cf351d67d574088d944c2e3b1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 5 Dec 2019 15:53:10 +0000 Subject: [PATCH 2/3] Fixup tests --- tests/rest/client/v1/test_presence.py | 3 +++ tests/rest/client/v1/test_profile.py | 10 +++++++++- tests/utils.py | 4 +++- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 66c2b68707..0fdff79aa7 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -15,6 +15,8 @@ from mock import Mock +from twisted.internet import defer + from synapse.rest.client.v1 import presence from synapse.types import UserID @@ -36,6 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) hs.presence_handler = Mock() + hs.presence_handler.set_state.return_value = defer.succeed(None) return hs diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 140d8b3772..12c5e95cb5 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -52,6 +52,14 @@ class MockHandlerProfileTestCase(unittest.TestCase): ] ) + self.mock_handler.get_displayname.return_value = defer.succeed(Mock()) + self.mock_handler.set_displayname.return_value = defer.succeed(Mock()) + self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock()) + self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock()) + self.mock_handler.check_profile_query_allowed.return_value = defer.succeed( + Mock() + ) + hs = yield setup_test_homeserver( self.addCleanup, "test", @@ -63,7 +71,7 @@ class MockHandlerProfileTestCase(unittest.TestCase): ) def _get_user_by_req(request=None, allow_guest=False): - return synapse.types.create_requester(myid) + return defer.succeed(synapse.types.create_requester(myid)) hs.get_auth().get_user_by_req = _get_user_by_req diff --git a/tests/utils.py b/tests/utils.py index de2ac1ed33..c57da59191 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -461,7 +461,9 @@ class MockHttpResource(HttpServer): try: args = [urlparse.unquote(u) for u in matcher.groups()] - (code, response) = yield func(mock_request, *args) + (code, response) = yield defer.ensureDeferred( + func(mock_request, *args) + ) return code, response except CodeMessageException as e: return (e.code, cs_error(e.msg, code=e.errcode)) From 410bfd035a5f2b77ad94d297f689fc29b9197218 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 5 Dec 2019 15:54:15 +0000 Subject: [PATCH 3/3] Newsfile --- changelog.d/6482.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/6482.misc diff --git a/changelog.d/6482.misc b/changelog.d/6482.misc new file mode 100644 index 0000000000..bdef9cf40a --- /dev/null +++ b/changelog.d/6482.misc @@ -0,0 +1 @@ +Port synapse.rest.client.v1 to async/await.