Port rest.client.v2

This commit is contained in:
Erik Johnston 2019-12-05 16:46:37 +00:00
parent af5d0ebc72
commit 9c41ba4c5f
23 changed files with 361 additions and 505 deletions

View file

@ -78,7 +78,7 @@ def interactive_auth_handler(orig):
""" """
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
res = defer.maybeDeferred(orig, *args, **kwargs) res = defer.ensureDeferred(orig(*args, **kwargs))
res.addErrback(_catch_incomplete_interactive_auth) res.addErrback(_catch_incomplete_interactive_auth)
return res return res

View file

@ -18,8 +18,6 @@ import logging
from six.moves import http_client from six.moves import http_client
from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, ThreepidValidationError from synapse.api.errors import Codes, SynapseError, ThreepidValidationError
from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.emailconfig import ThreepidBehaviour
@ -67,8 +65,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
template_text=template_text, template_text=template_text,
) )
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config: if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning( logger.warning(
@ -95,7 +92,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED, Codes.THREEPID_DENIED,
) )
existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email "email", email
) )
@ -106,7 +103,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
assert self.hs.config.account_threepid_delegate_email assert self.hs.config.account_threepid_delegate_email
# Have the configured identity server handle the request # Have the configured identity server handle the request
ret = yield self.identity_handler.requestEmailToken( ret = await self.identity_handler.requestEmailToken(
self.hs.config.account_threepid_delegate_email, self.hs.config.account_threepid_delegate_email,
email, email,
client_secret, client_secret,
@ -115,7 +112,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
) )
else: else:
# Send password reset emails from Synapse # Send password reset emails from Synapse
sid = yield self.identity_handler.send_threepid_validation( sid = await self.identity_handler.send_threepid_validation(
email, email,
client_secret, client_secret,
send_attempt, send_attempt,
@ -153,8 +150,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
[self.config.email_password_reset_template_failure_html], [self.config.email_password_reset_template_failure_html],
) )
@defer.inlineCallbacks async def on_GET(self, request, medium):
def on_GET(self, request, medium):
# We currently only handle threepid token submissions for email # We currently only handle threepid token submissions for email
if medium != "email": if medium != "email":
raise SynapseError( raise SynapseError(
@ -176,7 +172,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
# Attempt to validate a 3PID session # Attempt to validate a 3PID session
try: try:
# Mark the session as valid # Mark the session as valid
next_link = yield self.store.validate_threepid_session( next_link = await self.store.validate_threepid_session(
sid, client_secret, token, self.clock.time_msec() sid, client_secret, token, self.clock.time_msec()
) )
@ -218,8 +214,7 @@ class PasswordRestServlet(RestServlet):
self._set_password_handler = hs.get_set_password_handler() self._set_password_handler = hs.get_set_password_handler()
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
# there are two possibilities here. Either the user does not have an # there are two possibilities here. Either the user does not have an
@ -233,14 +228,14 @@ class PasswordRestServlet(RestServlet):
# In the second case, we require a password to confirm their identity. # In the second case, we require a password to confirm their identity.
if self.auth.has_access_token(request): if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
params = yield self.auth_handler.validate_user_via_ui_auth( params = await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request) requester, body, self.hs.get_ip_from_request(request)
) )
user_id = requester.user.to_string() user_id = requester.user.to_string()
else: else:
requester = None requester = None
result, params, _ = yield self.auth_handler.check_auth( result, params, _ = await self.auth_handler.check_auth(
[[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request) [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request)
) )
@ -254,7 +249,7 @@ class PasswordRestServlet(RestServlet):
# (See add_threepid in synapse/handlers/auth.py) # (See add_threepid in synapse/handlers/auth.py)
threepid["address"] = threepid["address"].lower() threepid["address"] = threepid["address"].lower()
# if using email, we must know about the email they're authing with! # if using email, we must know about the email they're authing with!
threepid_user_id = yield self.datastore.get_user_id_by_threepid( threepid_user_id = await self.datastore.get_user_id_by_threepid(
threepid["medium"], threepid["address"] threepid["medium"], threepid["address"]
) )
if not threepid_user_id: if not threepid_user_id:
@ -267,7 +262,7 @@ class PasswordRestServlet(RestServlet):
assert_params_in_dict(params, ["new_password"]) assert_params_in_dict(params, ["new_password"])
new_password = params["new_password"] new_password = params["new_password"]
yield self._set_password_handler.set_password(user_id, new_password, requester) await self._set_password_handler.set_password(user_id, new_password, requester)
return 200, {} return 200, {}
@ -286,8 +281,7 @@ class DeactivateAccountRestServlet(RestServlet):
self._deactivate_account_handler = hs.get_deactivate_account_handler() self._deactivate_account_handler = hs.get_deactivate_account_handler()
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
erase = body.get("erase", False) erase = body.get("erase", False)
if not isinstance(erase, bool): if not isinstance(erase, bool):
@ -297,19 +291,19 @@ class DeactivateAccountRestServlet(RestServlet):
Codes.BAD_JSON, Codes.BAD_JSON,
) )
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
# allow ASes to dectivate their own users # allow ASes to dectivate their own users
if requester.app_service: if requester.app_service:
yield self._deactivate_account_handler.deactivate_account( await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase requester.user.to_string(), erase
) )
return 200, {} return 200, {}
yield self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request) requester, body, self.hs.get_ip_from_request(request)
) )
result = yield self._deactivate_account_handler.deactivate_account( result = await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase, id_server=body.get("id_server") requester.user.to_string(), erase, id_server=body.get("id_server")
) )
if result: if result:
@ -346,8 +340,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
template_text=template_text, template_text=template_text,
) )
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config: if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning( logger.warning(
@ -371,7 +364,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED, Codes.THREEPID_DENIED,
) )
existing_user_id = yield self.store.get_user_id_by_threepid( existing_user_id = await self.store.get_user_id_by_threepid(
"email", body["email"] "email", body["email"]
) )
@ -382,7 +375,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
assert self.hs.config.account_threepid_delegate_email assert self.hs.config.account_threepid_delegate_email
# Have the configured identity server handle the request # Have the configured identity server handle the request
ret = yield self.identity_handler.requestEmailToken( ret = await self.identity_handler.requestEmailToken(
self.hs.config.account_threepid_delegate_email, self.hs.config.account_threepid_delegate_email,
email, email,
client_secret, client_secret,
@ -391,7 +384,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
) )
else: else:
# Send threepid validation emails from Synapse # Send threepid validation emails from Synapse
sid = yield self.identity_handler.send_threepid_validation( sid = await self.identity_handler.send_threepid_validation(
email, email,
client_secret, client_secret,
send_attempt, send_attempt,
@ -414,8 +407,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict( assert_params_in_dict(
body, ["client_secret", "country", "phone_number", "send_attempt"] body, ["client_secret", "country", "phone_number", "send_attempt"]
@ -435,7 +427,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED, Codes.THREEPID_DENIED,
) )
existing_user_id = yield self.store.get_user_id_by_threepid("msisdn", msisdn) existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None: if existing_user_id is not None:
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
@ -450,7 +442,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
"Adding phone numbers to user account is not supported by this homeserver", "Adding phone numbers to user account is not supported by this homeserver",
) )
ret = yield self.identity_handler.requestMsisdnToken( ret = await self.identity_handler.requestMsisdnToken(
self.hs.config.account_threepid_delegate_msisdn, self.hs.config.account_threepid_delegate_msisdn,
country, country,
phone_number, phone_number,
@ -484,8 +476,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
[self.config.email_add_threepid_template_failure_html], [self.config.email_add_threepid_template_failure_html],
) )
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config: if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning( logger.warning(
@ -508,7 +499,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Attempt to validate a 3PID session # Attempt to validate a 3PID session
try: try:
# Mark the session as valid # Mark the session as valid
next_link = yield self.store.validate_threepid_session( next_link = await self.store.validate_threepid_session(
sid, client_secret, token, self.clock.time_msec() sid, client_secret, token, self.clock.time_msec()
) )
@ -558,8 +549,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
if not self.config.account_threepid_delegate_msisdn: if not self.config.account_threepid_delegate_msisdn:
raise SynapseError( raise SynapseError(
400, 400,
@ -571,7 +561,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
assert_params_in_dict(body, ["client_secret", "sid", "token"]) assert_params_in_dict(body, ["client_secret", "sid", "token"])
# Proxy submit_token request to msisdn threepid delegate # Proxy submit_token request to msisdn threepid delegate
response = yield self.identity_handler.proxy_msisdn_submit_token( response = await self.identity_handler.proxy_msisdn_submit_token(
self.config.account_threepid_delegate_msisdn, self.config.account_threepid_delegate_msisdn,
body["client_secret"], body["client_secret"],
body["sid"], body["sid"],
@ -591,17 +581,15 @@ class ThreepidRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore() self.datastore = self.hs.get_datastore()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
threepids = yield self.datastore.user_get_threepids(requester.user.to_string()) threepids = await self.datastore.user_get_threepids(requester.user.to_string())
return 200, {"threepids": threepids} return 200, {"threepids": threepids}
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -615,11 +603,11 @@ class ThreepidRestServlet(RestServlet):
client_secret = threepid_creds["client_secret"] client_secret = threepid_creds["client_secret"]
sid = threepid_creds["sid"] sid = threepid_creds["sid"]
validation_session = yield self.identity_handler.validate_threepid_session( validation_session = await self.identity_handler.validate_threepid_session(
client_secret, sid client_secret, sid
) )
if validation_session: if validation_session:
yield self.auth_handler.add_threepid( await self.auth_handler.add_threepid(
user_id, user_id,
validation_session["medium"], validation_session["medium"],
validation_session["address"], validation_session["address"],
@ -643,9 +631,8 @@ class ThreepidAddRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -653,15 +640,15 @@ class ThreepidAddRestServlet(RestServlet):
client_secret = body["client_secret"] client_secret = body["client_secret"]
sid = body["sid"] sid = body["sid"]
yield self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request) requester, body, self.hs.get_ip_from_request(request)
) )
validation_session = yield self.identity_handler.validate_threepid_session( validation_session = await self.identity_handler.validate_threepid_session(
client_secret, sid client_secret, sid
) )
if validation_session: if validation_session:
yield self.auth_handler.add_threepid( await self.auth_handler.add_threepid(
user_id, user_id,
validation_session["medium"], validation_session["medium"],
validation_session["address"], validation_session["address"],
@ -683,8 +670,7 @@ class ThreepidBindRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
@ -693,10 +679,10 @@ class ThreepidBindRestServlet(RestServlet):
client_secret = body["client_secret"] client_secret = body["client_secret"]
id_access_token = body.get("id_access_token") # optional id_access_token = body.get("id_access_token") # optional
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
yield self.identity_handler.bind_threepid( await self.identity_handler.bind_threepid(
client_secret, sid, user_id, id_server, id_access_token client_secret, sid, user_id, id_server, id_access_token
) )
@ -713,12 +699,11 @@ class ThreepidUnbindRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.datastore = self.hs.get_datastore() self.datastore = self.hs.get_datastore()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
"""Unbind the given 3pid from a specific identity server, or identity servers that are """Unbind the given 3pid from a specific identity server, or identity servers that are
known to have this 3pid bound known to have this 3pid bound
""" """
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"]) assert_params_in_dict(body, ["medium", "address"])
@ -728,7 +713,7 @@ class ThreepidUnbindRestServlet(RestServlet):
# Attempt to unbind the threepid from an identity server. If id_server is None, try to # Attempt to unbind the threepid from an identity server. If id_server is None, try to
# unbind from all identity servers this threepid has been added to in the past # unbind from all identity servers this threepid has been added to in the past
result = yield self.identity_handler.try_unbind_threepid( result = await self.identity_handler.try_unbind_threepid(
requester.user.to_string(), requester.user.to_string(),
{"address": address, "medium": medium, "id_server": id_server}, {"address": address, "medium": medium, "id_server": id_server},
) )
@ -743,16 +728,15 @@ class ThreepidDeleteRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"]) assert_params_in_dict(body, ["medium", "address"])
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
try: try:
ret = yield self.auth_handler.delete_threepid( ret = await self.auth_handler.delete_threepid(
user_id, body["medium"], body["address"], body.get("id_server") user_id, body["medium"], body["address"], body.get("id_server")
) )
except Exception: except Exception:
@ -777,9 +761,8 @@ class WhoamiRestServlet(RestServlet):
super(WhoamiRestServlet, self).__init__() super(WhoamiRestServlet, self).__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
return 200, {"user_id": requester.user.to_string()} return 200, {"user_id": requester.user.to_string()}

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
@ -41,15 +39,14 @@ class AccountDataServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@defer.inlineCallbacks async def on_PUT(self, request, user_id, account_data_type):
def on_PUT(self, request, user_id, account_data_type): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
max_id = yield self.store.add_account_data_for_user( max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body user_id, account_data_type, body
) )
@ -57,13 +54,12 @@ class AccountDataServlet(RestServlet):
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_GET(self, request, user_id, account_data_type):
def on_GET(self, request, user_id, account_data_type): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.") raise AuthError(403, "Cannot get account data for other users.")
event = yield self.store.get_global_account_data_by_type_for_user( event = await self.store.get_global_account_data_by_type_for_user(
account_data_type, user_id account_data_type, user_id
) )
@ -91,9 +87,8 @@ class RoomAccountDataServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@defer.inlineCallbacks async def on_PUT(self, request, user_id, room_id, account_data_type):
def on_PUT(self, request, user_id, room_id, account_data_type): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
@ -106,7 +101,7 @@ class RoomAccountDataServlet(RestServlet):
" Use /rooms/!roomId:server.name/read_markers", " Use /rooms/!roomId:server.name/read_markers",
) )
max_id = yield self.store.add_account_data_to_room( max_id = await self.store.add_account_data_to_room(
user_id, room_id, account_data_type, body user_id, room_id, account_data_type, body
) )
@ -114,13 +109,12 @@ class RoomAccountDataServlet(RestServlet):
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_GET(self, request, user_id, room_id, account_data_type):
def on_GET(self, request, user_id, room_id, account_data_type): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.") raise AuthError(403, "Cannot get account data for other users.")
event = yield self.store.get_account_data_for_room_and_type( event = await self.store.get_account_data_for_room_and_type(
user_id, room_id, account_data_type user_id, room_id, account_data_type
) )

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
@ -45,13 +43,12 @@ class AccountValidityRenewServlet(RestServlet):
self.success_html = hs.config.account_validity.account_renewed_html_content self.success_html = hs.config.account_validity.account_renewed_html_content
self.failure_html = hs.config.account_validity.invalid_token_html_content self.failure_html = hs.config.account_validity.invalid_token_html_content
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request):
if b"token" not in request.args: if b"token" not in request.args:
raise SynapseError(400, "Missing renewal token") raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0] renewal_token = request.args[b"token"][0]
token_valid = yield self.account_activity_handler.renew_account( token_valid = await self.account_activity_handler.renew_account(
renewal_token.decode("utf8") renewal_token.decode("utf8")
) )
@ -67,7 +64,6 @@ class AccountValidityRenewServlet(RestServlet):
request.setHeader(b"Content-Length", b"%d" % (len(response),)) request.setHeader(b"Content-Length", b"%d" % (len(response),))
request.write(response.encode("utf8")) request.write(response.encode("utf8"))
finish_request(request) finish_request(request)
defer.returnValue(None)
class AccountValiditySendMailServlet(RestServlet): class AccountValiditySendMailServlet(RestServlet):
@ -85,18 +81,17 @@ class AccountValiditySendMailServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.account_validity = self.hs.config.account_validity self.account_validity = self.hs.config.account_validity
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
if not self.account_validity.renew_by_email_enabled: if not self.account_validity.renew_by_email_enabled:
raise AuthError( raise AuthError(
403, "Account renewal via email is disabled on this server." 403, "Account renewal via email is disabled on this server."
) )
requester = yield self.auth.get_user_by_req(request, allow_expired=True) requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
yield self.account_activity_handler.send_renewal_email_to_user(user_id) await self.account_activity_handler.send_renewal_email_to_user(user_id)
defer.returnValue((200, {})) return 200, {}
def register_servlets(hs, http_server): def register_servlets(hs, http_server):

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_API_PREFIX from synapse.api.urls import CLIENT_API_PREFIX
@ -171,8 +169,7 @@ class AuthRestServlet(RestServlet):
else: else:
raise SynapseError(404, "Unknown auth stage type") raise SynapseError(404, "Unknown auth stage type")
@defer.inlineCallbacks async def on_POST(self, request, stagetype):
def on_POST(self, request, stagetype):
session = parse_string(request, "session") session = parse_string(request, "session")
if not session: if not session:
@ -186,7 +183,7 @@ class AuthRestServlet(RestServlet):
authdict = {"response": response, "session": session} authdict = {"response": response, "session": session}
success = yield self.auth_handler.add_oob_auth( success = await self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request) LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
) )
@ -215,7 +212,7 @@ class AuthRestServlet(RestServlet):
session = request.args["session"][0] session = request.args["session"][0]
authdict = {"session": session} authdict = {"session": session}
success = yield self.auth_handler.add_oob_auth( success = await self.auth_handler.add_oob_auth(
LoginType.TERMS, authdict, self.hs.get_ip_from_request(request) LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
) )

View file

@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
import logging import logging
from twisted.internet import defer
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
@ -40,10 +38,9 @@ class CapabilitiesRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True) user = await self.store.get_user_by_id(requester.user.to_string())
user = yield self.store.get_user_by_id(requester.user.to_string())
change_password = bool(user["password_hash"]) change_password = bool(user["password_hash"])
response = { response = {

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api import errors from synapse.api import errors
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -42,10 +40,9 @@ class DevicesRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True) devices = await self.device_handler.get_devices_by_user(
devices = yield self.device_handler.get_devices_by_user(
requester.user.to_string() requester.user.to_string()
) )
return 200, {"devices": devices} return 200, {"devices": devices}
@ -67,9 +64,8 @@ class DeleteDevicesRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
try: try:
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -84,11 +80,11 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"]) assert_params_in_dict(body, ["devices"])
yield self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request) requester, body, self.hs.get_ip_from_request(request)
) )
yield self.device_handler.delete_devices( await self.device_handler.delete_devices(
requester.user.to_string(), body["devices"] requester.user.to_string(), body["devices"]
) )
return 200, {} return 200, {}
@ -108,18 +104,16 @@ class DeviceRestServlet(RestServlet):
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks async def on_GET(self, request, device_id):
def on_GET(self, request, device_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True) device = await self.device_handler.get_device(
device = yield self.device_handler.get_device(
requester.user.to_string(), device_id requester.user.to_string(), device_id
) )
return 200, device return 200, device
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks async def on_DELETE(self, request, device_id):
def on_DELETE(self, request, device_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
try: try:
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -132,19 +126,18 @@ class DeviceRestServlet(RestServlet):
else: else:
raise raise
yield self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request) requester, body, self.hs.get_ip_from_request(request)
) )
yield self.device_handler.delete_device(requester.user.to_string(), device_id) await self.device_handler.delete_device(requester.user.to_string(), device_id)
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_PUT(self, request, device_id):
def on_PUT(self, request, device_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
yield self.device_handler.update_device( await self.device_handler.update_device(
requester.user.to_string(), device_id, body requester.user.to_string(), device_id, body
) )
return 200, {} return 200, {}

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID from synapse.types import UserID
@ -35,10 +33,9 @@ class GetFilterRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
@defer.inlineCallbacks async def on_GET(self, request, user_id, filter_id):
def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if target_user != requester.user: if target_user != requester.user:
raise AuthError(403, "Cannot get filters for other users") raise AuthError(403, "Cannot get filters for other users")
@ -52,7 +49,7 @@ class GetFilterRestServlet(RestServlet):
raise SynapseError(400, "Invalid filter_id") raise SynapseError(400, "Invalid filter_id")
try: try:
filter_collection = yield self.filtering.get_user_filter( filter_collection = await self.filtering.get_user_filter(
user_localpart=target_user.localpart, filter_id=filter_id user_localpart=target_user.localpart, filter_id=filter_id
) )
except StoreError as e: except StoreError as e:
@ -72,11 +69,10 @@ class CreateFilterRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
@defer.inlineCallbacks async def on_POST(self, request, user_id):
def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if target_user != requester.user: if target_user != requester.user:
raise AuthError(403, "Cannot create filters for other users") raise AuthError(403, "Cannot create filters for other users")
@ -87,7 +83,7 @@ class CreateFilterRestServlet(RestServlet):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit)
filter_id = yield self.filtering.add_user_filter( filter_id = await self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_filter=content user_localpart=target_user.localpart, user_filter=content
) )

View file

@ -16,8 +16,6 @@
import logging import logging
from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID from synapse.types import GroupID
@ -38,24 +36,22 @@ class GroupServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id):
def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
group_description = yield self.groups_handler.get_group_profile( group_description = await self.groups_handler.get_group_profile(
group_id, requester_user_id group_id, requester_user_id
) )
return 200, group_description return 200, group_description
@defer.inlineCallbacks async def on_POST(self, request, group_id):
def on_POST(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
yield self.groups_handler.update_group_profile( await self.groups_handler.update_group_profile(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -74,12 +70,11 @@ class GroupSummaryServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id):
def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
get_group_summary = yield self.groups_handler.get_group_summary( get_group_summary = await self.groups_handler.get_group_summary(
group_id, requester_user_id group_id, requester_user_id
) )
@ -106,13 +101,12 @@ class GroupSummaryRoomsCatServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id, category_id, room_id):
def on_PUT(self, request, group_id, category_id, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_summary_room( resp = await self.groups_handler.update_group_summary_room(
group_id, group_id,
requester_user_id, requester_user_id,
room_id=room_id, room_id=room_id,
@ -122,12 +116,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
return 200, resp return 200, resp
@defer.inlineCallbacks async def on_DELETE(self, request, group_id, category_id, room_id):
def on_DELETE(self, request, group_id, category_id, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_summary_room( resp = await self.groups_handler.delete_group_summary_room(
group_id, requester_user_id, room_id=room_id, category_id=category_id group_id, requester_user_id, room_id=room_id, category_id=category_id
) )
@ -148,35 +141,32 @@ class GroupCategoryServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id, category_id):
def on_GET(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_category( category = await self.groups_handler.get_group_category(
group_id, requester_user_id, category_id=category_id group_id, requester_user_id, category_id=category_id
) )
return 200, category return 200, category
@defer.inlineCallbacks async def on_PUT(self, request, group_id, category_id):
def on_PUT(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_category( resp = await self.groups_handler.update_group_category(
group_id, requester_user_id, category_id=category_id, content=content group_id, requester_user_id, category_id=category_id, content=content
) )
return 200, resp return 200, resp
@defer.inlineCallbacks async def on_DELETE(self, request, group_id, category_id):
def on_DELETE(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_category( resp = await self.groups_handler.delete_group_category(
group_id, requester_user_id, category_id=category_id group_id, requester_user_id, category_id=category_id
) )
@ -195,12 +185,11 @@ class GroupCategoriesServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id):
def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_categories( category = await self.groups_handler.get_group_categories(
group_id, requester_user_id group_id, requester_user_id
) )
@ -219,35 +208,32 @@ class GroupRoleServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id, role_id):
def on_GET(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_role( category = await self.groups_handler.get_group_role(
group_id, requester_user_id, role_id=role_id group_id, requester_user_id, role_id=role_id
) )
return 200, category return 200, category
@defer.inlineCallbacks async def on_PUT(self, request, group_id, role_id):
def on_PUT(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_role( resp = await self.groups_handler.update_group_role(
group_id, requester_user_id, role_id=role_id, content=content group_id, requester_user_id, role_id=role_id, content=content
) )
return 200, resp return 200, resp
@defer.inlineCallbacks async def on_DELETE(self, request, group_id, role_id):
def on_DELETE(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_role( resp = await self.groups_handler.delete_group_role(
group_id, requester_user_id, role_id=role_id group_id, requester_user_id, role_id=role_id
) )
@ -266,12 +252,11 @@ class GroupRolesServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id):
def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_roles( category = await self.groups_handler.get_group_roles(
group_id, requester_user_id group_id, requester_user_id
) )
@ -298,13 +283,12 @@ class GroupSummaryUsersRoleServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id, role_id, user_id):
def on_PUT(self, request, group_id, role_id, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_summary_user( resp = await self.groups_handler.update_group_summary_user(
group_id, group_id,
requester_user_id, requester_user_id,
user_id=user_id, user_id=user_id,
@ -314,12 +298,11 @@ class GroupSummaryUsersRoleServlet(RestServlet):
return 200, resp return 200, resp
@defer.inlineCallbacks async def on_DELETE(self, request, group_id, role_id, user_id):
def on_DELETE(self, request, group_id, role_id, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_summary_user( resp = await self.groups_handler.delete_group_summary_user(
group_id, requester_user_id, user_id=user_id, role_id=role_id group_id, requester_user_id, user_id=user_id, role_id=role_id
) )
@ -338,12 +321,11 @@ class GroupRoomServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id):
def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_rooms_in_group( result = await self.groups_handler.get_rooms_in_group(
group_id, requester_user_id group_id, requester_user_id
) )
@ -362,12 +344,11 @@ class GroupUsersServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id):
def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_users_in_group( result = await self.groups_handler.get_users_in_group(
group_id, requester_user_id group_id, requester_user_id
) )
@ -386,12 +367,11 @@ class GroupInvitedUsersServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id):
def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_invited_users_in_group( result = await self.groups_handler.get_invited_users_in_group(
group_id, requester_user_id group_id, requester_user_id
) )
@ -409,14 +389,13 @@ class GroupSettingJoinPolicyServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id):
def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
result = yield self.groups_handler.set_group_join_policy( result = await self.groups_handler.set_group_join_policy(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -436,9 +415,8 @@ class GroupCreateServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
# TODO: Create group on remote server # TODO: Create group on remote server
@ -446,7 +424,7 @@ class GroupCreateServlet(RestServlet):
localpart = content.pop("localpart") localpart = content.pop("localpart")
group_id = GroupID(localpart, self.server_name).to_string() group_id = GroupID(localpart, self.server_name).to_string()
result = yield self.groups_handler.create_group( result = await self.groups_handler.create_group(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -467,24 +445,22 @@ class GroupAdminRoomsServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id, room_id):
def on_PUT(self, request, group_id, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
result = yield self.groups_handler.add_room_to_group( result = await self.groups_handler.add_room_to_group(
group_id, requester_user_id, room_id, content group_id, requester_user_id, room_id, content
) )
return 200, result return 200, result
@defer.inlineCallbacks async def on_DELETE(self, request, group_id, room_id):
def on_DELETE(self, request, group_id, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
result = yield self.groups_handler.remove_room_from_group( result = await self.groups_handler.remove_room_from_group(
group_id, requester_user_id, room_id group_id, requester_user_id, room_id
) )
@ -506,13 +482,12 @@ class GroupAdminRoomsConfigServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id, room_id, config_key):
def on_PUT(self, request, group_id, room_id, config_key): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
result = yield self.groups_handler.update_room_in_group( result = await self.groups_handler.update_room_in_group(
group_id, requester_user_id, room_id, config_key, content group_id, requester_user_id, room_id, config_key, content
) )
@ -535,14 +510,13 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks async def on_PUT(self, request, group_id, user_id):
def on_PUT(self, request, group_id, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
config = content.get("config", {}) config = content.get("config", {})
result = yield self.groups_handler.invite( result = await self.groups_handler.invite(
group_id, user_id, requester_user_id, config group_id, user_id, requester_user_id, config
) )
@ -563,13 +537,12 @@ class GroupAdminUsersKickServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id, user_id):
def on_PUT(self, request, group_id, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
result = yield self.groups_handler.remove_user_from_group( result = await self.groups_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content group_id, user_id, requester_user_id, content
) )
@ -588,13 +561,12 @@ class GroupSelfLeaveServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id):
def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
result = yield self.groups_handler.remove_user_from_group( result = await self.groups_handler.remove_user_from_group(
group_id, requester_user_id, requester_user_id, content group_id, requester_user_id, requester_user_id, content
) )
@ -613,13 +585,12 @@ class GroupSelfJoinServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id):
def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
result = yield self.groups_handler.join_group( result = await self.groups_handler.join_group(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -638,13 +609,12 @@ class GroupSelfAcceptInviteServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id):
def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
result = yield self.groups_handler.accept_invite( result = await self.groups_handler.accept_invite(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -663,14 +633,13 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks async def on_PUT(self, request, group_id):
def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
publicise = content["publicise"] publicise = content["publicise"]
yield self.store.update_group_publicity(group_id, requester_user_id, publicise) await self.store.update_group_publicity(group_id, requester_user_id, publicise)
return 200, {} return 200, {}
@ -688,11 +657,10 @@ class PublicisedGroupsForUserServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, user_id):
def on_GET(self, request, user_id): await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.get_user_by_req(request, allow_guest=True)
result = yield self.groups_handler.get_publicised_groups_for_user(user_id) result = await self.groups_handler.get_publicised_groups_for_user(user_id)
return 200, result return 200, result
@ -710,14 +678,13 @@ class PublicisedGroupsForUsersServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
user_ids = content["user_ids"] user_ids = content["user_ids"]
result = yield self.groups_handler.bulk_get_publicised_groups(user_ids) result = await self.groups_handler.bulk_get_publicised_groups(user_ids)
return 200, result return 200, result
@ -734,12 +701,11 @@ class GroupsForUserServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_joined_groups(requester_user_id) result = await self.groups_handler.get_joined_groups(requester_user_id)
return 200, result return 200, result

View file

@ -16,8 +16,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -71,9 +69,8 @@ class KeyUploadServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
@trace(opname="upload_keys") @trace(opname="upload_keys")
@defer.inlineCallbacks async def on_POST(self, request, device_id):
def on_POST(self, request, device_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -103,7 +100,7 @@ class KeyUploadServlet(RestServlet):
400, "To upload keys, you must pass device_id when authenticating" 400, "To upload keys, you must pass device_id when authenticating"
) )
result = yield self.e2e_keys_handler.upload_keys_for_user( result = await self.e2e_keys_handler.upload_keys_for_user(
user_id, device_id, body user_id, device_id, body
) )
return 200, result return 200, result
@ -154,13 +151,12 @@ class KeyQueryServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.query_devices(body, timeout, user_id) result = await self.e2e_keys_handler.query_devices(body, timeout, user_id)
return 200, result return 200, result
@ -185,9 +181,8 @@ class KeyChangesServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from") from_token_string = parse_string(request, "from")
set_tag("from", from_token_string) set_tag("from", from_token_string)
@ -200,7 +195,7 @@ class KeyChangesServlet(RestServlet):
user_id = requester.user.to_string() user_id = requester.user.to_string()
results = yield self.device_handler.get_user_ids_changed(user_id, from_token) results = await self.device_handler.get_user_ids_changed(user_id, from_token)
return 200, results return 200, results
@ -231,12 +226,11 @@ class OneTimeKeyServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout) result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
return 200, result return 200, result
@ -263,17 +257,16 @@ class SigningKeyUploadServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
yield self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request) requester, body, self.hs.get_ip_from_request(request)
) )
result = yield self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
return 200, result return 200, result
@ -315,13 +308,12 @@ class SignaturesUploadServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.upload_signatures_for_device_keys( result = await self.e2e_keys_handler.upload_signatures_for_device_keys(
user_id, body user_id, body
) )
return 200, result return 200, result

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_integer, parse_string
@ -35,9 +33,8 @@ class NotificationsServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
from_token = parse_string(request, "from", required=False) from_token = parse_string(request, "from", required=False)
@ -46,16 +43,16 @@ class NotificationsServlet(RestServlet):
limit = min(limit, 500) limit = min(limit, 500)
push_actions = yield self.store.get_push_actions_for_user( push_actions = await self.store.get_push_actions_for_user(
user_id, from_token, limit, only_highlight=(only == "highlight") user_id, from_token, limit, only_highlight=(only == "highlight")
) )
receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
user_id, "m.read" user_id, "m.read"
) )
notif_event_ids = [pa["event_id"] for pa in push_actions] notif_event_ids = [pa["event_id"] for pa in push_actions]
notif_events = yield self.store.get_events(notif_event_ids) notif_events = await self.store.get_events(notif_event_ids)
returned_push_actions = [] returned_push_actions = []
@ -68,7 +65,7 @@ class NotificationsServlet(RestServlet):
"actions": pa["actions"], "actions": pa["actions"],
"ts": pa["received_ts"], "ts": pa["received_ts"],
"event": ( "event": (
yield self._event_serializer.serialize_event( await self._event_serializer.serialize_event(
notif_events[pa["event_id"]], notif_events[pa["event_id"]],
self.clock.time_msec(), self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id, event_format=format_event_for_client_v2_without_room_id,

View file

@ -16,8 +16,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -68,9 +66,8 @@ class IdTokenServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.server_name = hs.config.server_name self.server_name = hs.config.server_name
@defer.inlineCallbacks async def on_POST(self, request, user_id):
def on_POST(self, request, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot request tokens for other users.") raise AuthError(403, "Cannot request tokens for other users.")
@ -81,7 +78,7 @@ class IdTokenServlet(RestServlet):
token = random_string(24) token = random_string(24)
ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS
yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) await self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
return ( return (
200, 200,

View file

@ -20,8 +20,6 @@ from typing import List, Union
from six import string_types from six import string_types
from twisted.internet import defer
import synapse import synapse
import synapse.types import synapse.types
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
@ -102,8 +100,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
template_text=template_text, template_text=template_text,
) )
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.hs.config.local_threepid_handling_disabled_due_to_email_config: if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning( logger.warning(
@ -129,7 +126,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED, Codes.THREEPID_DENIED,
) )
existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", body["email"] "email", body["email"]
) )
@ -140,7 +137,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
assert self.hs.config.account_threepid_delegate_email assert self.hs.config.account_threepid_delegate_email
# Have the configured identity server handle the request # Have the configured identity server handle the request
ret = yield self.identity_handler.requestEmailToken( ret = await self.identity_handler.requestEmailToken(
self.hs.config.account_threepid_delegate_email, self.hs.config.account_threepid_delegate_email,
email, email,
client_secret, client_secret,
@ -149,7 +146,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
) )
else: else:
# Send registration emails from Synapse # Send registration emails from Synapse
sid = yield self.identity_handler.send_threepid_validation( sid = await self.identity_handler.send_threepid_validation(
email, email,
client_secret, client_secret,
send_attempt, send_attempt,
@ -175,8 +172,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict( assert_params_in_dict(
@ -197,7 +193,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED, Codes.THREEPID_DENIED,
) )
existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"msisdn", msisdn "msisdn", msisdn
) )
@ -215,7 +211,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
400, "Registration by phone number is not supported on this homeserver" 400, "Registration by phone number is not supported on this homeserver"
) )
ret = yield self.identity_handler.requestMsisdnToken( ret = await self.identity_handler.requestMsisdnToken(
self.hs.config.account_threepid_delegate_msisdn, self.hs.config.account_threepid_delegate_msisdn,
country, country,
phone_number, phone_number,
@ -258,8 +254,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
[self.config.email_registration_template_failure_html], [self.config.email_registration_template_failure_html],
) )
@defer.inlineCallbacks async def on_GET(self, request, medium):
def on_GET(self, request, medium):
if medium != "email": if medium != "email":
raise SynapseError( raise SynapseError(
400, "This medium is currently not supported for registration" 400, "This medium is currently not supported for registration"
@ -280,7 +275,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
# Attempt to validate a 3PID session # Attempt to validate a 3PID session
try: try:
# Mark the session as valid # Mark the session as valid
next_link = yield self.store.validate_threepid_session( next_link = await self.store.validate_threepid_session(
sid, client_secret, token, self.clock.time_msec() sid, client_secret, token, self.clock.time_msec()
) )
@ -338,8 +333,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
), ),
) )
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request):
if not self.hs.config.enable_registration: if not self.hs.config.enable_registration:
raise SynapseError( raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
@ -347,11 +341,11 @@ class UsernameAvailabilityRestServlet(RestServlet):
ip = self.hs.get_ip_from_request(request) ip = self.hs.get_ip_from_request(request)
with self.ratelimiter.ratelimit(ip) as wait_deferred: with self.ratelimiter.ratelimit(ip) as wait_deferred:
yield wait_deferred await wait_deferred
username = parse_string(request, "username", required=True) username = parse_string(request, "username", required=True)
yield self.registration_handler.check_username(username) await self.registration_handler.check_username(username)
return 200, {"available": True} return 200, {"available": True}
@ -382,8 +376,7 @@ class RegisterRestServlet(RestServlet):
) )
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
client_addr = request.getClientIP() client_addr = request.getClientIP()
@ -408,7 +401,7 @@ class RegisterRestServlet(RestServlet):
kind = request.args[b"kind"][0] kind = request.args[b"kind"][0]
if kind == b"guest": if kind == b"guest":
ret = yield self._do_guest_registration(body, address=client_addr) ret = await self._do_guest_registration(body, address=client_addr)
return ret return ret
elif kind != b"user": elif kind != b"user":
raise UnrecognizedRequestError( raise UnrecognizedRequestError(
@ -435,7 +428,7 @@ class RegisterRestServlet(RestServlet):
appservice = None appservice = None
if self.auth.has_access_token(request): if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request) appservice = await self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes which have completely # fork off as soon as possible for ASes which have completely
# different registration flows to normal users # different registration flows to normal users
@ -455,7 +448,7 @@ class RegisterRestServlet(RestServlet):
access_token = self.auth.get_access_token_from_request(request) access_token = self.auth.get_access_token_from_request(request)
if isinstance(desired_username, string_types): if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration( result = await self._do_appservice_registration(
desired_username, access_token, body desired_username, access_token, body
) )
return 200, result # we throw for non 200 responses return 200, result # we throw for non 200 responses
@ -495,13 +488,13 @@ class RegisterRestServlet(RestServlet):
) )
if desired_username is not None: if desired_username is not None:
yield self.registration_handler.check_username( await self.registration_handler.check_username(
desired_username, desired_username,
guest_access_token=guest_access_token, guest_access_token=guest_access_token,
assigned_user_id=registered_user_id, assigned_user_id=registered_user_id,
) )
auth_result, params, session_id = yield self.auth_handler.check_auth( auth_result, params, session_id = await self.auth_handler.check_auth(
self._registration_flows, body, self.hs.get_ip_from_request(request) self._registration_flows, body, self.hs.get_ip_from_request(request)
) )
@ -557,7 +550,7 @@ class RegisterRestServlet(RestServlet):
medium = auth_result[login_type]["medium"] medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"] address = auth_result[login_type]["address"]
existing_user_id = yield self.store.get_user_id_by_threepid( existing_user_id = await self.store.get_user_id_by_threepid(
medium, address medium, address
) )
@ -568,7 +561,7 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_IN_USE, Codes.THREEPID_IN_USE,
) )
registered_user_id = yield self.registration_handler.register_user( registered_user_id = await self.registration_handler.register_user(
localpart=desired_username, localpart=desired_username,
password=new_password, password=new_password,
guest_access_token=guest_access_token, guest_access_token=guest_access_token,
@ -581,7 +574,7 @@ class RegisterRestServlet(RestServlet):
if is_threepid_reserved( if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid self.hs.config.mau_limits_reserved_threepids, threepid
): ):
yield self.store.upsert_monthly_active_user(registered_user_id) await self.store.upsert_monthly_active_user(registered_user_id)
# remember that we've now registered that user account, and with # remember that we've now registered that user account, and with
# what user ID (since the user may not have specified) # what user ID (since the user may not have specified)
@ -591,12 +584,12 @@ class RegisterRestServlet(RestServlet):
registered = True registered = True
return_dict = yield self._create_registration_details( return_dict = await self._create_registration_details(
registered_user_id, params registered_user_id, params
) )
if registered: if registered:
yield self.registration_handler.post_registration_actions( await self.registration_handler.post_registration_actions(
user_id=registered_user_id, user_id=registered_user_id,
auth_result=auth_result, auth_result=auth_result,
access_token=return_dict.get("access_token"), access_token=return_dict.get("access_token"),
@ -607,15 +600,13 @@ class RegisterRestServlet(RestServlet):
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
@defer.inlineCallbacks async def _do_appservice_registration(self, username, as_token, body):
def _do_appservice_registration(self, username, as_token, body): user_id = await self.registration_handler.appservice_register(
user_id = yield self.registration_handler.appservice_register(
username, as_token username, as_token
) )
return (yield self._create_registration_details(user_id, body)) return await self._create_registration_details(user_id, body)
@defer.inlineCallbacks async def _create_registration_details(self, user_id, params):
def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user """Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token. Allocates device_id if one was not given; also creates access_token.
@ -631,18 +622,17 @@ class RegisterRestServlet(RestServlet):
if not params.get("inhibit_login", False): if not params.get("inhibit_login", False):
device_id = params.get("device_id") device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name") initial_display_name = params.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device( device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=False user_id, device_id, initial_display_name, is_guest=False
) )
result.update({"access_token": access_token, "device_id": device_id}) result.update({"access_token": access_token, "device_id": device_id})
return result return result
@defer.inlineCallbacks async def _do_guest_registration(self, params, address=None):
def _do_guest_registration(self, params, address=None):
if not self.hs.config.allow_guest_access: if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled") raise SynapseError(403, "Guest access is disabled")
user_id = yield self.registration_handler.register_user( user_id = await self.registration_handler.register_user(
make_guest=True, address=address make_guest=True, address=address
) )
@ -650,7 +640,7 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it. # we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name") initial_display_name = params.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device( device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=True user_id, device_id, initial_display_name, is_guest=True
) )

View file

@ -21,8 +21,6 @@ any time to reflect changes in the MSC.
import logging import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, RelationTypes from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -86,11 +84,10 @@ class RelationSendServlet(RestServlet):
request, self.on_PUT_or_POST, request, *args, **kwargs request, self.on_PUT_or_POST, request, *args, **kwargs
) )
@defer.inlineCallbacks async def on_PUT_or_POST(
def on_PUT_or_POST(
self, request, room_id, parent_id, relation_type, event_type, txn_id=None self, request, room_id, parent_id, relation_type, event_type, txn_id=None
): ):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
if event_type == EventTypes.Member: if event_type == EventTypes.Member:
# Add relations to a membership is meaningless, so we just deny it # Add relations to a membership is meaningless, so we just deny it
@ -114,7 +111,7 @@ class RelationSendServlet(RestServlet):
"sender": requester.user.to_string(), "sender": requester.user.to_string(),
} }
event = yield self.event_creation_handler.create_and_send_nonmember_event( event = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict=event_dict, txn_id=txn_id requester, event_dict=event_dict, txn_id=txn_id
) )
@ -140,17 +137,18 @@ class RelationPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks async def on_GET(
def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): self, request, room_id, parent_id, relation_type=None, event_type=None
requester = yield self.auth.get_user_by_req(request, allow_guest=True) ):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.check_in_room_or_world_readable( await self.auth.check_in_room_or_world_readable(
room_id, requester.user.to_string() room_id, requester.user.to_string()
) )
# This gets the original event and checks that a) the event exists and # This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it. # b) the user is allowed to view it.
event = yield self.event_handler.get_event(requester.user, room_id, parent_id) event = await self.event_handler.get_event(requester.user, room_id, parent_id)
limit = parse_integer(request, "limit", default=5) limit = parse_integer(request, "limit", default=5)
from_token = parse_string(request, "from") from_token = parse_string(request, "from")
@ -167,7 +165,7 @@ class RelationPaginationServlet(RestServlet):
if to_token: if to_token:
to_token = RelationPaginationToken.from_string(to_token) to_token = RelationPaginationToken.from_string(to_token)
pagination_chunk = yield self.store.get_relations_for_event( pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id, event_id=parent_id,
relation_type=relation_type, relation_type=relation_type,
event_type=event_type, event_type=event_type,
@ -176,7 +174,7 @@ class RelationPaginationServlet(RestServlet):
to_token=to_token, to_token=to_token,
) )
events = yield self.store.get_events_as_list( events = await self.store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk] [c["event_id"] for c in pagination_chunk.chunk]
) )
@ -184,13 +182,13 @@ class RelationPaginationServlet(RestServlet):
# We set bundle_aggregations to False when retrieving the original # We set bundle_aggregations to False when retrieving the original
# event because we want the content before relations were applied to # event because we want the content before relations were applied to
# it. # it.
original_event = yield self._event_serializer.serialize_event( original_event = await self._event_serializer.serialize_event(
event, now, bundle_aggregations=False event, now, bundle_aggregations=False
) )
# Similarly, we don't allow relations to be applied to relations, so we # Similarly, we don't allow relations to be applied to relations, so we
# return the original relations without any aggregations on top of them # return the original relations without any aggregations on top of them
# here. # here.
events = yield self._event_serializer.serialize_events( events = await self._event_serializer.serialize_events(
events, now, bundle_aggregations=False events, now, bundle_aggregations=False
) )
@ -232,17 +230,18 @@ class RelationAggregationPaginationServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks async def on_GET(
def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): self, request, room_id, parent_id, relation_type=None, event_type=None
requester = yield self.auth.get_user_by_req(request, allow_guest=True) ):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.check_in_room_or_world_readable( await self.auth.check_in_room_or_world_readable(
room_id, requester.user.to_string() room_id, requester.user.to_string()
) )
# This checks that a) the event exists and b) the user is allowed to # This checks that a) the event exists and b) the user is allowed to
# view it. # view it.
event = yield self.event_handler.get_event(requester.user, room_id, parent_id) event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type not in (RelationTypes.ANNOTATION, None): if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'") raise SynapseError(400, "Relation type must be 'annotation'")
@ -262,7 +261,7 @@ class RelationAggregationPaginationServlet(RestServlet):
if to_token: if to_token:
to_token = AggregationPaginationToken.from_string(to_token) to_token = AggregationPaginationToken.from_string(to_token)
pagination_chunk = yield self.store.get_aggregation_groups_for_event( pagination_chunk = await self.store.get_aggregation_groups_for_event(
event_id=parent_id, event_id=parent_id,
event_type=event_type, event_type=event_type,
limit=limit, limit=limit,
@ -311,17 +310,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.check_in_room_or_world_readable( await self.auth.check_in_room_or_world_readable(
room_id, requester.user.to_string() room_id, requester.user.to_string()
) )
# This checks that a) the event exists and b) the user is allowed to # This checks that a) the event exists and b) the user is allowed to
# view it. # view it.
yield self.event_handler.get_event(requester.user, room_id, parent_id) await self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type != RelationTypes.ANNOTATION: if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'") raise SynapseError(400, "Relation type must be 'annotation'")
@ -336,7 +334,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
if to_token: if to_token:
to_token = RelationPaginationToken.from_string(to_token) to_token = RelationPaginationToken.from_string(to_token)
result = yield self.store.get_relations_for_event( result = await self.store.get_relations_for_event(
event_id=parent_id, event_id=parent_id,
relation_type=relation_type, relation_type=relation_type,
event_type=event_type, event_type=event_type,
@ -346,12 +344,12 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
to_token=to_token, to_token=to_token,
) )
events = yield self.store.get_events_as_list( events = await self.store.get_events_as_list(
[c["event_id"] for c in result.chunk] [c["event_id"] for c in result.chunk]
) )
now = self.clock.time_msec() now = self.clock.time_msec()
events = yield self._event_serializer.serialize_events(events, now) events = await self._event_serializer.serialize_events(events, now)
return_value = result.to_dict() return_value = result.to_dict()
return_value["chunk"] = events return_value["chunk"] = events

View file

@ -18,8 +18,6 @@ import logging
from six import string_types from six import string_types
from six.moves import http_client from six.moves import http_client
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -42,9 +40,8 @@ class ReportEventRestServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks async def on_POST(self, request, room_id, event_id):
def on_POST(self, request, room_id, event_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -63,7 +60,7 @@ class ReportEventRestServlet(RestServlet):
Codes.BAD_JSON, Codes.BAD_JSON,
) )
yield self.store.add_event_report( await self.store.add_event_report(
room_id=room_id, room_id=room_id,
event_id=event_id, event_id=event_id,
user_id=user_id, user_id=user_id,

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -43,8 +41,7 @@ class RoomKeysServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
@defer.inlineCallbacks async def on_PUT(self, request, room_id, session_id):
def on_PUT(self, request, room_id, session_id):
""" """
Uploads one or more encrypted E2E room keys for backup purposes. Uploads one or more encrypted E2E room keys for backup purposes.
room_id: the ID of the room the keys are for (optional) room_id: the ID of the room the keys are for (optional)
@ -123,7 +120,7 @@ class RoomKeysServlet(RestServlet):
} }
} }
""" """
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
version = parse_string(request, "version") version = parse_string(request, "version")
@ -134,11 +131,10 @@ class RoomKeysServlet(RestServlet):
if room_id: if room_id:
body = {"rooms": {room_id: body}} body = {"rooms": {room_id: body}}
ret = yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
return 200, ret return 200, ret
@defer.inlineCallbacks async def on_GET(self, request, room_id, session_id):
def on_GET(self, request, room_id, session_id):
""" """
Retrieves one or more encrypted E2E room keys for backup purposes. Retrieves one or more encrypted E2E room keys for backup purposes.
Symmetric with the PUT version of the API. Symmetric with the PUT version of the API.
@ -190,11 +186,11 @@ class RoomKeysServlet(RestServlet):
} }
} }
""" """
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
version = parse_string(request, "version") version = parse_string(request, "version")
room_keys = yield self.e2e_room_keys_handler.get_room_keys( room_keys = await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id user_id, version, room_id, session_id
) )
@ -220,8 +216,7 @@ class RoomKeysServlet(RestServlet):
return 200, room_keys return 200, room_keys
@defer.inlineCallbacks async def on_DELETE(self, request, room_id, session_id):
def on_DELETE(self, request, room_id, session_id):
""" """
Deletes one or more encrypted E2E room keys for a user for backup purposes. Deletes one or more encrypted E2E room keys for a user for backup purposes.
@ -235,11 +230,11 @@ class RoomKeysServlet(RestServlet):
the version must already have been created via the /change_secret API. the version must already have been created via the /change_secret API.
""" """
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
version = parse_string(request, "version") version = parse_string(request, "version")
ret = yield self.e2e_room_keys_handler.delete_room_keys( ret = await self.e2e_room_keys_handler.delete_room_keys(
user_id, version, room_id, session_id user_id, version, room_id, session_id
) )
return 200, ret return 200, ret
@ -257,8 +252,7 @@ class RoomKeysNewVersionServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
""" """
Create a new backup version for this user's room_keys with the given Create a new backup version for this user's room_keys with the given
info. The version is allocated by the server and returned to the user info. The version is allocated by the server and returned to the user
@ -288,11 +282,11 @@ class RoomKeysNewVersionServlet(RestServlet):
"version": 12345 "version": 12345
} }
""" """
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
info = parse_json_object_from_request(request) info = parse_json_object_from_request(request)
new_version = yield self.e2e_room_keys_handler.create_version(user_id, info) new_version = await self.e2e_room_keys_handler.create_version(user_id, info)
return 200, {"version": new_version} return 200, {"version": new_version}
# we deliberately don't have a PUT /version, as these things really should # we deliberately don't have a PUT /version, as these things really should
@ -311,8 +305,7 @@ class RoomKeysVersionServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
@defer.inlineCallbacks async def on_GET(self, request, version):
def on_GET(self, request, version):
""" """
Retrieve the version information about a given version of the user's Retrieve the version information about a given version of the user's
room_keys backup. If the version part is missing, returns info about the room_keys backup. If the version part is missing, returns info about the
@ -330,18 +323,17 @@ class RoomKeysVersionServlet(RestServlet):
"auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
} }
""" """
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
try: try:
info = yield self.e2e_room_keys_handler.get_version_info(user_id, version) info = await self.e2e_room_keys_handler.get_version_info(user_id, version)
except SynapseError as e: except SynapseError as e:
if e.code == 404: if e.code == 404:
raise SynapseError(404, "No backup found", Codes.NOT_FOUND) raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
return 200, info return 200, info
@defer.inlineCallbacks async def on_DELETE(self, request, version):
def on_DELETE(self, request, version):
""" """
Delete the information about a given version of the user's Delete the information about a given version of the user's
room_keys backup. If the version part is missing, deletes the most room_keys backup. If the version part is missing, deletes the most
@ -354,14 +346,13 @@ class RoomKeysVersionServlet(RestServlet):
if version is None: if version is None:
raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND) raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND)
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
yield self.e2e_room_keys_handler.delete_version(user_id, version) await self.e2e_room_keys_handler.delete_version(user_id, version)
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_PUT(self, request, version):
def on_PUT(self, request, version):
""" """
Update the information about a given version of the user's room_keys backup. Update the information about a given version of the user's room_keys backup.
@ -382,7 +373,7 @@ class RoomKeysVersionServlet(RestServlet):
Content-Type: application/json Content-Type: application/json
{} {}
""" """
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
info = parse_json_object_from_request(request) info = parse_json_object_from_request(request)
@ -391,7 +382,7 @@ class RoomKeysVersionServlet(RestServlet):
400, "No version specified to update", Codes.MISSING_PARAM 400, "No version specified to update", Codes.MISSING_PARAM
) )
yield self.e2e_room_keys_handler.update_version(user_id, version, info) await self.e2e_room_keys_handler.update_version(user_id, version, info)
return 200, {} return 200, {}

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -59,9 +57,8 @@ class RoomUpgradeRestServlet(RestServlet):
self._room_creation_handler = hs.get_room_creation_handler() self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth() self._auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request, room_id):
def on_POST(self, request, room_id): requester = await self._auth.get_user_by_req(request)
requester = yield self._auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert_params_in_dict(content, ("new_version",)) assert_params_in_dict(content, ("new_version",))
@ -74,7 +71,7 @@ class RoomUpgradeRestServlet(RestServlet):
Codes.UNSUPPORTED_ROOM_VERSION, Codes.UNSUPPORTED_ROOM_VERSION,
) )
new_room_id = yield self._room_creation_handler.upgrade_room( new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version requester, room_id, new_version
) )

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.http import servlet from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.logging.opentracing import set_tag, trace from synapse.logging.opentracing import set_tag, trace
@ -51,15 +49,14 @@ class SendToDeviceRestServlet(servlet.RestServlet):
request, self._put, request, message_type, txn_id request, self._put, request, message_type, txn_id
) )
@defer.inlineCallbacks async def _put(self, request, message_type, txn_id):
def _put(self, request, message_type, txn_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
sender_user_id = requester.user.to_string() sender_user_id = requester.user.to_string()
yield self.device_message_handler.send_device_message( await self.device_message_handler.send_device_message(
sender_user_id, message_type, content["messages"] sender_user_id, message_type, content["messages"]
) )

View file

@ -18,8 +18,6 @@ import logging
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
@ -87,8 +85,7 @@ class SyncRestServlet(RestServlet):
self._server_notices_sender = hs.get_server_notices_sender() self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request):
if b"from" in request.args: if b"from" in request.args:
# /events used to use 'from', but /sync uses 'since'. # /events used to use 'from', but /sync uses 'since'.
# Lets be helpful and whine if we see a 'from'. # Lets be helpful and whine if we see a 'from'.
@ -96,7 +93,7 @@ class SyncRestServlet(RestServlet):
400, "'from' is not a valid query parameter. Did you mean 'since'?" 400, "'from' is not a valid query parameter. Did you mean 'since'?"
) )
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = requester.user user = requester.user
device_id = requester.device_id device_id = requester.device_id
@ -138,7 +135,7 @@ class SyncRestServlet(RestServlet):
filter_collection = FilterCollection(filter_object) filter_collection = FilterCollection(filter_object)
else: else:
try: try:
filter_collection = yield self.filtering.get_user_filter( filter_collection = await self.filtering.get_user_filter(
user.localpart, filter_id user.localpart, filter_id
) )
except StoreError as err: except StoreError as err:
@ -161,20 +158,20 @@ class SyncRestServlet(RestServlet):
since_token = None since_token = None
# send any outstanding server notices to the user. # send any outstanding server notices to the user.
yield self._server_notices_sender.on_user_syncing(user.to_string()) await self._server_notices_sender.on_user_syncing(user.to_string())
affect_presence = set_presence != PresenceState.OFFLINE affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence: if affect_presence:
yield self.presence_handler.set_state( await self.presence_handler.set_state(
user, {"presence": set_presence}, True user, {"presence": set_presence}, True
) )
context = yield self.presence_handler.user_syncing( context = await self.presence_handler.user_syncing(
user.to_string(), affect_presence=affect_presence user.to_string(), affect_presence=affect_presence
) )
with context: with context:
sync_result = yield self.sync_handler.wait_for_sync_for_user( sync_result = await self.sync_handler.wait_for_sync_for_user(
sync_config, sync_config,
since_token=since_token, since_token=since_token,
timeout=timeout, timeout=timeout,
@ -182,14 +179,13 @@ class SyncRestServlet(RestServlet):
) )
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
response_content = yield self.encode_response( response_content = await self.encode_response(
time_now, sync_result, requester.access_token_id, filter_collection time_now, sync_result, requester.access_token_id, filter_collection
) )
return 200, response_content return 200, response_content
@defer.inlineCallbacks async def encode_response(self, time_now, sync_result, access_token_id, filter):
def encode_response(self, time_now, sync_result, access_token_id, filter):
if filter.event_format == "client": if filter.event_format == "client":
event_formatter = format_event_for_client_v2_without_room_id event_formatter = format_event_for_client_v2_without_room_id
elif filter.event_format == "federation": elif filter.event_format == "federation":
@ -197,7 +193,7 @@ class SyncRestServlet(RestServlet):
else: else:
raise Exception("Unknown event format %s" % (filter.event_format,)) raise Exception("Unknown event format %s" % (filter.event_format,))
joined = yield self.encode_joined( joined = await self.encode_joined(
sync_result.joined, sync_result.joined,
time_now, time_now,
access_token_id, access_token_id,
@ -205,11 +201,11 @@ class SyncRestServlet(RestServlet):
event_formatter, event_formatter,
) )
invited = yield self.encode_invited( invited = await self.encode_invited(
sync_result.invited, time_now, access_token_id, event_formatter sync_result.invited, time_now, access_token_id, event_formatter
) )
archived = yield self.encode_archived( archived = await self.encode_archived(
sync_result.archived, sync_result.archived,
time_now, time_now,
access_token_id, access_token_id,
@ -250,8 +246,9 @@ class SyncRestServlet(RestServlet):
] ]
} }
@defer.inlineCallbacks async def encode_joined(
def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter): self, rooms, time_now, token_id, event_fields, event_formatter
):
""" """
Encode the joined rooms in a sync result Encode the joined rooms in a sync result
@ -272,7 +269,7 @@ class SyncRestServlet(RestServlet):
""" """
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = yield self.encode_room( joined[room.room_id] = await self.encode_room(
room, room,
time_now, time_now,
token_id, token_id,
@ -283,8 +280,7 @@ class SyncRestServlet(RestServlet):
return joined return joined
@defer.inlineCallbacks async def encode_invited(self, rooms, time_now, token_id, event_formatter):
def encode_invited(self, rooms, time_now, token_id, event_formatter):
""" """
Encode the invited rooms in a sync result Encode the invited rooms in a sync result
@ -304,7 +300,7 @@ class SyncRestServlet(RestServlet):
""" """
invited = {} invited = {}
for room in rooms: for room in rooms:
invite = yield self._event_serializer.serialize_event( invite = await self._event_serializer.serialize_event(
room.invite, room.invite,
time_now, time_now,
token_id=token_id, token_id=token_id,
@ -319,8 +315,9 @@ class SyncRestServlet(RestServlet):
return invited return invited
@defer.inlineCallbacks async def encode_archived(
def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter): self, rooms, time_now, token_id, event_fields, event_formatter
):
""" """
Encode the archived rooms in a sync result Encode the archived rooms in a sync result
@ -341,7 +338,7 @@ class SyncRestServlet(RestServlet):
""" """
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = yield self.encode_room( joined[room.room_id] = await self.encode_room(
room, room,
time_now, time_now,
token_id, token_id,
@ -352,8 +349,7 @@ class SyncRestServlet(RestServlet):
return joined return joined
@defer.inlineCallbacks async def encode_room(
def encode_room(
self, room, time_now, token_id, joined, only_fields, event_formatter self, room, time_now, token_id, joined, only_fields, event_formatter
): ):
""" """
@ -401,8 +397,8 @@ class SyncRestServlet(RestServlet):
event.room_id, event.room_id,
) )
serialized_state = yield serialize(state_events) serialized_state = await serialize(state_events)
serialized_timeline = yield serialize(timeline_events) serialized_timeline = await serialize(timeline_events)
account_data = room.account_data account_data = room.account_data

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
@ -37,13 +35,12 @@ class TagListServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks async def on_GET(self, request, user_id, room_id):
def on_GET(self, request, user_id, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get tags for other users.") raise AuthError(403, "Cannot get tags for other users.")
tags = yield self.store.get_tags_for_room(user_id, room_id) tags = await self.store.get_tags_for_room(user_id, room_id)
return 200, {"tags": tags} return 200, {"tags": tags}
@ -64,27 +61,25 @@ class TagServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@defer.inlineCallbacks async def on_PUT(self, request, user_id, room_id, tag):
def on_PUT(self, request, user_id, room_id, tag): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.") raise AuthError(403, "Cannot add tags for other users.")
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_DELETE(self, request, user_id, room_id, tag):
def on_DELETE(self, request, user_id, room_id, tag): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.") raise AuthError(403, "Cannot add tags for other users.")
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) self.notifier.on_new_event("account_data_key", max_id, users=[user_id])

View file

@ -16,8 +16,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
@ -35,11 +33,10 @@ class ThirdPartyProtocolsServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.get_user_by_req(request, allow_guest=True)
protocols = yield self.appservice_handler.get_3pe_protocols() protocols = await self.appservice_handler.get_3pe_protocols()
return 200, protocols return 200, protocols
@ -52,11 +49,10 @@ class ThirdPartyProtocolServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks async def on_GET(self, request, protocol):
def on_GET(self, request, protocol): await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.get_user_by_req(request, allow_guest=True)
protocols = yield self.appservice_handler.get_3pe_protocols( protocols = await self.appservice_handler.get_3pe_protocols(
only_protocol=protocol only_protocol=protocol
) )
if protocol in protocols: if protocol in protocols:
@ -74,14 +70,13 @@ class ThirdPartyUserServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks async def on_GET(self, request, protocol):
def on_GET(self, request, protocol): await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args fields = request.args
fields.pop(b"access_token", None) fields.pop(b"access_token", None)
results = yield self.appservice_handler.query_3pe( results = await self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields ThirdPartyEntityKind.USER, protocol, fields
) )
@ -97,14 +92,13 @@ class ThirdPartyLocationServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks async def on_GET(self, request, protocol):
def on_GET(self, request, protocol): await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args fields = request.args
fields.pop(b"access_token", None) fields.pop(b"access_token", None)
results = yield self.appservice_handler.query_3pe( results = await self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields ThirdPartyEntityKind.LOCATION, protocol, fields
) )

View file

@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
@ -32,8 +30,7 @@ class TokenRefreshRestServlet(RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__() super(TokenRefreshRestServlet, self).__init__()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
raise AuthError(403, "tokenrefresh is no longer supported.") raise AuthError(403, "tokenrefresh is no longer supported.")

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
@ -38,8 +36,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler() self.user_directory_handler = hs.get_user_directory_handler()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
"""Searches for users in directory """Searches for users in directory
Returns: Returns:
@ -56,7 +53,7 @@ class UserDirectorySearchRestServlet(RestServlet):
] ]
} }
""" """
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
if not self.hs.config.user_directory_search_enabled: if not self.hs.config.user_directory_search_enabled:
@ -72,7 +69,7 @@ class UserDirectorySearchRestServlet(RestServlet):
except Exception: except Exception:
raise SynapseError(400, "`search_term` is required field") raise SynapseError(400, "`search_term` is required field")
results = yield self.user_directory_handler.search_users( results = await self.user_directory_handler.search_users(
user_id, search_term, limit user_id, search_term, limit
) )