Convert federation client to async/await. (#7975)

This commit is contained in:
Patrick Cloke 2020-07-30 08:01:33 -04:00 committed by GitHub
parent 4cce8ef74e
commit c978f6c451
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 209 additions and 221 deletions

1
changelog.d/7975.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -609,13 +609,15 @@ class SynapseCmd(cmd.Cmd):
@defer.inlineCallbacks
def _do_event_stream(self, timeout):
res = yield self.http_client.get_json(
self._url() + "/events",
{
"access_token": self._tok(),
"timeout": str(timeout),
"from": self.event_stream_token,
},
res = yield defer.ensureDeferred(
self.http_client.get_json(
self._url() + "/events",
{
"access_token": self._tok(),
"timeout": str(timeout),
"from": self.event_stream_token,
},
)
)
print(json.dumps(res, indent=4))

View file

@ -632,18 +632,20 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
)
try:
query_response = yield self.client.post_json(
destination=perspective_name,
path="/_matrix/key/v2/query",
data={
"server_keys": {
server_name: {
key_id: {"minimum_valid_until_ts": min_valid_ts}
for key_id, min_valid_ts in server_keys.items()
query_response = yield defer.ensureDeferred(
self.client.post_json(
destination=perspective_name,
path="/_matrix/key/v2/query",
data={
"server_keys": {
server_name: {
key_id: {"minimum_valid_until_ts": min_valid_ts}
for key_id, min_valid_ts in server_keys.items()
}
for server_name, server_keys in keys_to_fetch.items()
}
for server_name, server_keys in keys_to_fetch.items()
}
},
},
)
)
except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve upon
@ -792,23 +794,25 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
time_now_ms = self.clock.time_msec()
try:
response = yield self.client.get_json(
destination=server_name,
path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id),
ignore_backoff=True,
# we only give the remote server 10s to respond. It should be an
# easy request to handle, so if it doesn't reply within 10s, it's
# probably not going to.
#
# Furthermore, when we are acting as a notary server, we cannot
# wait all day for all of the origin servers, as the requesting
# server will otherwise time out before we can respond.
#
# (Note that get_json may make 4 attempts, so this can still take
# almost 45 seconds to fetch the headers, plus up to another 60s to
# read the response).
timeout=10000,
response = yield defer.ensureDeferred(
self.client.get_json(
destination=server_name,
path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id),
ignore_backoff=True,
# we only give the remote server 10s to respond. It should be an
# easy request to handle, so if it doesn't reply within 10s, it's
# probably not going to.
#
# Furthermore, when we are acting as a notary server, we cannot
# wait all day for all of the origin servers, as the requesting
# server will otherwise time out before we can respond.
#
# (Note that get_json may make 4 attempts, so this can still take
# almost 45 seconds to fetch the headers, plus up to another 60s to
# read the response).
timeout=10000,
)
)
except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve

View file

@ -135,7 +135,7 @@ class FederationClient(FederationBase):
and try the request anyway.
Returns:
a Deferred which will eventually yield a JSON object from the
a Awaitable which will eventually yield a JSON object from the
response
"""
sent_queries_counter.labels(query_type).inc()
@ -157,7 +157,7 @@ class FederationClient(FederationBase):
content (dict): The query content.
Returns:
a Deferred which will eventually yield a JSON object from the
an Awaitable which will eventually yield a JSON object from the
response
"""
sent_queries_counter.labels("client_device_keys").inc()
@ -180,7 +180,7 @@ class FederationClient(FederationBase):
content (dict): The query content.
Returns:
a Deferred which will eventually yield a JSON object from the
an Awaitable which will eventually yield a JSON object from the
response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
@ -900,7 +900,7 @@ class FederationClient(FederationBase):
party instance
Returns:
Deferred[Dict[str, Any]]: The response from the remote server, or None if
Awaitable[Dict[str, Any]]: The response from the remote server, or None if
`remote_server` is the same as the local server_name
Raises:

View file

@ -288,8 +288,7 @@ class FederationSender(object):
for destination in destinations:
self._get_per_destination_queue(destination).send_pdu(pdu, order)
@defer.inlineCallbacks
def send_read_receipt(self, receipt: ReadReceipt):
async def send_read_receipt(self, receipt: ReadReceipt) -> None:
"""Send a RR to any other servers in the room
Args:
@ -330,9 +329,7 @@ class FederationSender(object):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
domains = yield defer.ensureDeferred(
self.state.get_current_hosts_in_room(room_id)
)
domains = await self.state.get_current_hosts_in_room(room_id)
domains = [
d
for d in domains
@ -387,8 +384,7 @@ class FederationSender(object):
queue.flush_read_receipts_for_room(room_id)
@preserve_fn # the caller should not yield on this
@defer.inlineCallbacks
def send_presence(self, states: List[UserPresenceState]):
async def send_presence(self, states: List[UserPresenceState]):
"""Send the new presence states to the appropriate destinations.
This actually queues up the presence states ready for sending and
@ -423,7 +419,7 @@ class FederationSender(object):
if not states_map:
break
yield self._process_presence_inner(list(states_map.values()))
await self._process_presence_inner(list(states_map.values()))
except Exception:
logger.exception("Error sending presence states to servers")
finally:
@ -450,14 +446,11 @@ class FederationSender(object):
self._get_per_destination_queue(destination).send_presence(states)
@measure_func("txnqueue._process_presence")
@defer.inlineCallbacks
def _process_presence_inner(self, states: List[UserPresenceState]):
async def _process_presence_inner(self, states: List[UserPresenceState]):
"""Given a list of states populate self.pending_presence_by_dest and
poke to send a new transaction to each destination
"""
hosts_and_states = yield defer.ensureDeferred(
get_interested_remotes(self.store, states, self.state)
)
hosts_and_states = await get_interested_remotes(self.store, states, self.state)
for destinations, states in hosts_and_states:
for destination in destinations:

View file

@ -18,8 +18,6 @@ import logging
import urllib
from typing import Any, Dict, Optional
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.urls import (
@ -51,7 +49,7 @@ class TransportLayerClient(object):
event_id (str): The event we want the context at.
Returns:
Deferred: Results in a dict received from the remote homeserver.
Awaitable: Results in a dict received from the remote homeserver.
"""
logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)
@ -75,7 +73,7 @@ class TransportLayerClient(object):
giving up. None indicates no timeout.
Returns:
Deferred: Results in a dict received from the remote homeserver.
Awaitable: Results in a dict received from the remote homeserver.
"""
logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id)
@ -96,7 +94,7 @@ class TransportLayerClient(object):
limit (int)
Returns:
Deferred: Results in a dict received from the remote homeserver.
Awaitable: Results in a dict received from the remote homeserver.
"""
logger.debug(
"backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s",
@ -118,16 +116,15 @@ class TransportLayerClient(object):
destination, path=path, args=args, try_trailing_slash_on_400=True
)
@defer.inlineCallbacks
@log_function
def send_transaction(self, transaction, json_data_callback=None):
async def send_transaction(self, transaction, json_data_callback=None):
""" Sends the given Transaction to its destination
Args:
transaction (Transaction)
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
Fails with ``HTTPRequestException`` if we get an HTTP response
@ -154,7 +151,7 @@ class TransportLayerClient(object):
path = _create_v1_path("/send/%s", transaction.transaction_id)
response = yield self.client.put_json(
response = await self.client.put_json(
transaction.destination,
path=path,
data=json_data,
@ -166,14 +163,13 @@ class TransportLayerClient(object):
return response
@defer.inlineCallbacks
@log_function
def make_query(
async def make_query(
self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
):
path = _create_v1_path("/query/%s", query_type)
content = yield self.client.get_json(
content = await self.client.get_json(
destination=destination,
path=path,
args=args,
@ -184,9 +180,10 @@ class TransportLayerClient(object):
return content
@defer.inlineCallbacks
@log_function
def make_membership_event(self, destination, room_id, user_id, membership, params):
async def make_membership_event(
self, destination, room_id, user_id, membership, params
):
"""Asks a remote server to build and sign us a membership event
Note that this does not append any events to any graphs.
@ -200,7 +197,7 @@ class TransportLayerClient(object):
request.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body (ie, the new event).
Fails with ``HTTPRequestException`` if we get an HTTP response
@ -231,7 +228,7 @@ class TransportLayerClient(object):
ignore_backoff = True
retry_on_dns_fail = True
content = yield self.client.get_json(
content = await self.client.get_json(
destination=destination,
path=path,
args=params,
@ -242,34 +239,31 @@ class TransportLayerClient(object):
return content
@defer.inlineCallbacks
@log_function
def send_join_v1(self, destination, room_id, event_id, content):
async def send_join_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
response = await self.client.put_json(
destination=destination, path=path, data=content
)
return response
@defer.inlineCallbacks
@log_function
def send_join_v2(self, destination, room_id, event_id, content):
async def send_join_v2(self, destination, room_id, event_id, content):
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
response = await self.client.put_json(
destination=destination, path=path, data=content
)
return response
@defer.inlineCallbacks
@log_function
def send_leave_v1(self, destination, room_id, event_id, content):
async def send_leave_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
response = await self.client.put_json(
destination=destination,
path=path,
data=content,
@ -282,12 +276,11 @@ class TransportLayerClient(object):
return response
@defer.inlineCallbacks
@log_function
def send_leave_v2(self, destination, room_id, event_id, content):
async def send_leave_v2(self, destination, room_id, event_id, content):
path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
response = await self.client.put_json(
destination=destination,
path=path,
data=content,
@ -300,31 +293,28 @@ class TransportLayerClient(object):
return response
@defer.inlineCallbacks
@log_function
def send_invite_v1(self, destination, room_id, event_id, content):
async def send_invite_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
response = await self.client.put_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
return response
@defer.inlineCallbacks
@log_function
def send_invite_v2(self, destination, room_id, event_id, content):
async def send_invite_v2(self, destination, room_id, event_id, content):
path = _create_v2_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
response = await self.client.put_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
return response
@defer.inlineCallbacks
@log_function
def get_public_rooms(
async def get_public_rooms(
self,
remote_server: str,
limit: Optional[int] = None,
@ -355,7 +345,7 @@ class TransportLayerClient(object):
data["filter"] = search_filter
try:
response = yield self.client.post_json(
response = await self.client.post_json(
destination=remote_server, path=path, data=data, ignore_backoff=True
)
except HttpResponseException as e:
@ -381,7 +371,7 @@ class TransportLayerClient(object):
args["since"] = [since_token]
try:
response = yield self.client.get_json(
response = await self.client.get_json(
destination=remote_server, path=path, args=args, ignore_backoff=True
)
except HttpResponseException as e:
@ -396,29 +386,26 @@ class TransportLayerClient(object):
return response
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
async def exchange_third_party_invite(self, destination, room_id, event_dict):
path = _create_v1_path("/exchange_third_party_invite/%s", room_id)
response = yield self.client.put_json(
response = await self.client.put_json(
destination=destination, path=path, data=event_dict
)
return response
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
async def get_event_auth(self, destination, room_id, event_id):
path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
content = yield self.client.get_json(destination=destination, path=path)
content = await self.client.get_json(destination=destination, path=path)
return content
@defer.inlineCallbacks
@log_function
def query_client_keys(self, destination, query_content, timeout):
async def query_client_keys(self, destination, query_content, timeout):
"""Query the device keys for a list of user ids hosted on a remote
server.
@ -453,14 +440,13 @@ class TransportLayerClient(object):
"""
path = _create_v1_path("/user/keys/query")
content = yield self.client.post_json(
content = await self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout
)
return content
@defer.inlineCallbacks
@log_function
def query_user_devices(self, destination, user_id, timeout):
async def query_user_devices(self, destination, user_id, timeout):
"""Query the devices for a user id hosted on a remote server.
Response:
@ -493,14 +479,13 @@ class TransportLayerClient(object):
"""
path = _create_v1_path("/user/devices/%s", user_id)
content = yield self.client.get_json(
content = await self.client.get_json(
destination=destination, path=path, timeout=timeout
)
return content
@defer.inlineCallbacks
@log_function
def claim_client_keys(self, destination, query_content, timeout):
async def claim_client_keys(self, destination, query_content, timeout):
"""Claim one-time keys for a list of devices hosted on a remote server.
Request:
@ -532,14 +517,13 @@ class TransportLayerClient(object):
path = _create_v1_path("/user/keys/claim")
content = yield self.client.post_json(
content = await self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout
)
return content
@defer.inlineCallbacks
@log_function
def get_missing_events(
async def get_missing_events(
self,
destination,
room_id,
@ -551,7 +535,7 @@ class TransportLayerClient(object):
):
path = _create_v1_path("/get_missing_events/%s", room_id)
content = yield self.client.post_json(
content = await self.client.post_json(
destination=destination,
path=path,
data={

View file

@ -23,39 +23,32 @@ logger = logging.getLogger(__name__)
def _create_rerouter(func_name):
"""Returns a function that looks at the group id and calls the function
"""Returns an async function that looks at the group id and calls the function
on federation or the local group server if the group is local
"""
def f(self, group_id, *args, **kwargs):
async def f(self, group_id, *args, **kwargs):
if self.is_mine_id(group_id):
return getattr(self.groups_server_handler, func_name)(
return await getattr(self.groups_server_handler, func_name)(
group_id, *args, **kwargs
)
else:
destination = get_domain_from_id(group_id)
d = getattr(self.transport_client, func_name)(
destination, group_id, *args, **kwargs
)
# Capture errors returned by the remote homeserver and
# re-throw specific errors as SynapseErrors. This is so
# when the remote end responds with things like 403 Not
# In Group, we can communicate that to the client instead
# of a 500.
def http_response_errback(failure):
failure.trap(HttpResponseException)
e = failure.value
try:
return await getattr(self.transport_client, func_name)(
destination, group_id, *args, **kwargs
)
except HttpResponseException as e:
# Capture errors returned by the remote homeserver and
# re-throw specific errors as SynapseErrors. This is so
# when the remote end responds with things like 403 Not
# In Group, we can communicate that to the client instead
# of a 500.
raise e.to_synapse_error()
def request_failed_errback(failure):
failure.trap(RequestSendFailed)
except RequestSendFailed:
raise SynapseError(502, "Failed to contact group server")
d.addErrback(http_response_errback)
d.addErrback(request_failed_errback)
return d
return f

View file

@ -121,8 +121,7 @@ class MatrixFederationRequest(object):
return self.json
@defer.inlineCallbacks
def _handle_json_response(reactor, timeout_sec, request, response):
async def _handle_json_response(reactor, timeout_sec, request, response):
"""
Reads the JSON body of a response, with a timeout
@ -141,7 +140,7 @@ def _handle_json_response(reactor, timeout_sec, request, response):
d = treq.json_content(response)
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
body = yield make_deferred_yieldable(d)
body = await make_deferred_yieldable(d)
except TimeoutError as e:
logger.warning(
"{%s} [%s] Timed out reading response", request.txn_id, request.destination,
@ -224,8 +223,7 @@ class MatrixFederationHttpClient(object):
self._cooperator = Cooperator(scheduler=schedule)
@defer.inlineCallbacks
def _send_request_with_optional_trailing_slash(
async def _send_request_with_optional_trailing_slash(
self, request, try_trailing_slash_on_400=False, **send_request_args
):
"""Wrapper for _send_request which can optionally retry the request
@ -246,10 +244,10 @@ class MatrixFederationHttpClient(object):
(except 429).
Returns:
Deferred[Dict]: Parsed JSON response body.
Dict: Parsed JSON response body.
"""
try:
response = yield self._send_request(request, **send_request_args)
response = await self._send_request(request, **send_request_args)
except HttpResponseException as e:
# Received an HTTP error > 300. Check if it meets the requirements
# to retry with a trailing slash
@ -265,12 +263,11 @@ class MatrixFederationHttpClient(object):
logger.info("Retrying request with trailing slash")
request.path += "/"
response = yield self._send_request(request, **send_request_args)
response = await self._send_request(request, **send_request_args)
return response
@defer.inlineCallbacks
def _send_request(
async def _send_request(
self,
request,
retry_on_dns_fail=True,
@ -311,7 +308,7 @@ class MatrixFederationHttpClient(object):
backoff_on_404 (bool): Back off if we get a 404
Returns:
Deferred[twisted.web.client.Response]: resolves with the HTTP
twisted.web.client.Response: resolves with the HTTP
response object on success.
Raises:
@ -335,7 +332,7 @@ class MatrixFederationHttpClient(object):
):
raise FederationDeniedError(request.destination)
limiter = yield synapse.util.retryutils.get_retry_limiter(
limiter = await synapse.util.retryutils.get_retry_limiter(
request.destination,
self.clock,
self._store,
@ -433,7 +430,7 @@ class MatrixFederationHttpClient(object):
reactor=self.reactor,
)
response = yield request_deferred
response = await request_deferred
except TimeoutError as e:
raise RequestSendFailed(e, can_retry=True) from e
except DNSLookupError as e:
@ -474,7 +471,7 @@ class MatrixFederationHttpClient(object):
)
try:
body = yield make_deferred_yieldable(d)
body = await make_deferred_yieldable(d)
except Exception as e:
# Eh, we're already going to raise an exception so lets
# ignore if this fails.
@ -528,7 +525,7 @@ class MatrixFederationHttpClient(object):
delay,
)
yield self.clock.sleep(delay)
await self.clock.sleep(delay)
retries_left -= 1
else:
raise
@ -591,8 +588,7 @@ class MatrixFederationHttpClient(object):
)
return auth_headers
@defer.inlineCallbacks
def put_json(
async def put_json(
self,
destination,
path,
@ -636,7 +632,7 @@ class MatrixFederationHttpClient(object):
enabled.
Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
dict|list: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@ -658,7 +654,7 @@ class MatrixFederationHttpClient(object):
json=data,
)
response = yield self._send_request_with_optional_trailing_slash(
response = await self._send_request_with_optional_trailing_slash(
request,
try_trailing_slash_on_400,
backoff_on_404=backoff_on_404,
@ -667,14 +663,13 @@ class MatrixFederationHttpClient(object):
timeout=timeout,
)
body = yield _handle_json_response(
body = await _handle_json_response(
self.reactor, self.default_timeout, request, response
)
return body
@defer.inlineCallbacks
def post_json(
async def post_json(
self,
destination,
path,
@ -707,7 +702,7 @@ class MatrixFederationHttpClient(object):
args (dict): query params
Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
dict|list: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@ -725,7 +720,7 @@ class MatrixFederationHttpClient(object):
method="POST", destination=destination, path=path, query=args, json=data
)
response = yield self._send_request(
response = await self._send_request(
request,
long_retries=long_retries,
timeout=timeout,
@ -737,13 +732,12 @@ class MatrixFederationHttpClient(object):
else:
_sec_timeout = self.default_timeout
body = yield _handle_json_response(
body = await _handle_json_response(
self.reactor, _sec_timeout, request, response
)
return body
@defer.inlineCallbacks
def get_json(
async def get_json(
self,
destination,
path,
@ -775,7 +769,7 @@ class MatrixFederationHttpClient(object):
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
dict|list: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@ -792,7 +786,7 @@ class MatrixFederationHttpClient(object):
method="GET", destination=destination, path=path, query=args
)
response = yield self._send_request_with_optional_trailing_slash(
response = await self._send_request_with_optional_trailing_slash(
request,
try_trailing_slash_on_400,
backoff_on_404=False,
@ -801,14 +795,13 @@ class MatrixFederationHttpClient(object):
timeout=timeout,
)
body = yield _handle_json_response(
body = await _handle_json_response(
self.reactor, self.default_timeout, request, response
)
return body
@defer.inlineCallbacks
def delete_json(
async def delete_json(
self,
destination,
path,
@ -836,7 +829,7 @@ class MatrixFederationHttpClient(object):
args (dict): query params
Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
dict|list: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@ -853,20 +846,19 @@ class MatrixFederationHttpClient(object):
method="DELETE", destination=destination, path=path, query=args
)
response = yield self._send_request(
response = await self._send_request(
request,
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
)
body = yield _handle_json_response(
body = await _handle_json_response(
self.reactor, self.default_timeout, request, response
)
return body
@defer.inlineCallbacks
def get_file(
async def get_file(
self,
destination,
path,
@ -886,7 +878,7 @@ class MatrixFederationHttpClient(object):
and try the request anyway.
Returns:
Deferred[tuple[int, dict]]: Resolves with an (int,dict) tuple of
tuple[int, dict]: Resolves with an (int,dict) tuple of
the file length and a dict of the response headers.
Raises:
@ -903,7 +895,7 @@ class MatrixFederationHttpClient(object):
method="GET", destination=destination, path=path, query=args
)
response = yield self._send_request(
response = await self._send_request(
request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff
)
@ -912,7 +904,7 @@ class MatrixFederationHttpClient(object):
try:
d = _readBodyToFile(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.reactor)
length = yield make_deferred_yieldable(d)
length = await make_deferred_yieldable(d)
except Exception as e:
logger.warning(
"{%s} [%s] Error reading response: %s",

View file

@ -102,11 +102,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
}
persp_deferred = defer.Deferred()
@defer.inlineCallbacks
def get_perspectives(**kwargs):
async def get_perspectives(**kwargs):
self.assertEquals(current_context().request, "11")
with PreserveLoggingContext():
yield persp_deferred
await persp_deferred
return persp_resp
self.http_client.post_json.side_effect = get_perspectives
@ -355,7 +354,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
}
signedjson.sign.sign_json(response, SERVER_NAME, testkey)
def get_json(destination, path, **kwargs):
async def get_json(destination, path, **kwargs):
self.assertEqual(destination, SERVER_NAME)
self.assertEqual(path, "/_matrix/key/v2/server/key1")
return response
@ -444,7 +443,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect a perspectives-server key query
"""
def post_json(destination, path, data, **kwargs):
async def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
self.assertEqual(path, "/_matrix/key/v2/query")
@ -580,14 +579,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
# remove the perspectives server's signature
response = build_response()
del response["signatures"][self.mock_perspective_server.server_name]
self.http_client.post_json.return_value = {"server_keys": [response]}
keys = get_key_from_perspectives(response)
self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
# remove the origin server's signature
response = build_response()
del response["signatures"][SERVER_NAME]
self.http_client.post_json.return_value = {"server_keys": [response]}
keys = get_key_from_perspectives(response)
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")

View file

@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, room
from synapse.types import UserID
from tests import unittest
from tests.test_utils import make_awaitable
class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
@ -78,9 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1))
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
@ -109,9 +110,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1))
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
@ -147,9 +148,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1))
return_value=make_awaitable(("", 1))
)
# Artificially raise the complexity
@ -203,9 +204,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1))
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
@ -233,9 +234,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1))
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(

View file

@ -47,13 +47,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
mock_send_transaction.return_value = defer.succeed({})
mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
)
self.successResultOf(sender.send_read_receipt(receipt))
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()
@ -87,13 +87,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
mock_send_transaction.return_value = defer.succeed({})
mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
)
self.successResultOf(sender.send_read_receipt(receipt))
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()
@ -125,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
)
self.successResultOf(sender.send_read_receipt(receipt))
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()
mock_send_transaction.assert_not_called()

View file

@ -16,8 +16,6 @@
from mock import Mock
from twisted.internet import defer
import synapse
import synapse.api.errors
from synapse.api.constants import EventTypes
@ -26,6 +24,7 @@ from synapse.rest.client.v1 import directory, login, room
from synapse.types import RoomAlias, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase):
@ -71,7 +70,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
def test_get_remote_association(self):
self.mock_federation.make_query.return_value = defer.succeed(
self.mock_federation.make_query.return_value = make_awaitable(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
)

View file

@ -24,6 +24,7 @@ from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID
from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
@ -138,7 +139,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_other_name(self):
self.mock_federation.make_query.return_value = defer.succeed(
self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"}
)

View file

@ -58,7 +58,9 @@ class FederationClientTests(HomeserverTestCase):
@defer.inlineCallbacks
def do_request():
with LoggingContext("one") as context:
fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
fetch_d = defer.ensureDeferred(
self.cl.get_json("testserv:8008", "foo/bar")
)
# Nothing happened yet
self.assertNoResult(fetch_d)
@ -120,7 +122,9 @@ class FederationClientTests(HomeserverTestCase):
"""
If the DNS lookup returns an error, it will bubble up.
"""
d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
d = defer.ensureDeferred(
self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
)
self.pump()
f = self.failureResultOf(d)
@ -128,7 +132,9 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
def test_client_connection_refused(self):
d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
d = defer.ensureDeferred(
self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
)
self.pump()
@ -154,7 +160,9 @@ class FederationClientTests(HomeserverTestCase):
If the HTTP request is not connected and is timed out, it'll give a
ConnectingCancelledError or TimeoutError.
"""
d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
d = defer.ensureDeferred(
self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
)
self.pump()
@ -184,7 +192,9 @@ class FederationClientTests(HomeserverTestCase):
If the HTTP request is connected, but gets no response before being
timed out, it'll give a ResponseNeverReceived.
"""
d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
d = defer.ensureDeferred(
self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
)
self.pump()
@ -226,7 +236,7 @@ class FederationClientTests(HomeserverTestCase):
# Try making a GET request to a blacklisted IPv4 address
# ------------------------------------------------------
# Make the request
d = cl.get_json("internal:8008", "foo/bar", timeout=10000)
d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000))
# Nothing happened yet
self.assertNoResult(d)
@ -244,7 +254,9 @@ class FederationClientTests(HomeserverTestCase):
# Try making a POST request to a blacklisted IPv6 address
# -------------------------------------------------------
# Make the request
d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
d = defer.ensureDeferred(
cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
)
# Nothing has happened yet
self.assertNoResult(d)
@ -263,7 +275,7 @@ class FederationClientTests(HomeserverTestCase):
# Try making a GET request to a non-blacklisted IPv4 address
# ----------------------------------------------------------
# Make the request
d = cl.post_json("fine:8008", "foo/bar", timeout=10000)
d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000))
# Nothing has happened yet
self.assertNoResult(d)
@ -286,7 +298,7 @@ class FederationClientTests(HomeserverTestCase):
request = MatrixFederationRequest(
method="GET", destination="testserv:8008", path="foo/bar"
)
d = self.cl._send_request(request, timeout=10000)
d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000))
self.pump()
@ -310,7 +322,9 @@ class FederationClientTests(HomeserverTestCase):
If the HTTP request is connected, but gets no response before being
timed out, it'll give a ResponseNeverReceived.
"""
d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
d = defer.ensureDeferred(
self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
)
self.pump()
@ -342,7 +356,9 @@ class FederationClientTests(HomeserverTestCase):
requiring a trailing slash. We need to retry the request with a
trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622.
"""
d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
d = defer.ensureDeferred(
self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
)
# Send the request
self.pump()
@ -395,7 +411,9 @@ class FederationClientTests(HomeserverTestCase):
See test_client_requires_trailing_slashes() for context.
"""
d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
d = defer.ensureDeferred(
self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
)
# Send the request
self.pump()
@ -432,7 +450,11 @@ class FederationClientTests(HomeserverTestCase):
self.failureResultOf(d)
def test_client_sends_body(self):
self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"})
defer.ensureDeferred(
self.cl.post_json(
"testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
)
)
self.pump()
@ -453,7 +475,7 @@ class FederationClientTests(HomeserverTestCase):
def test_closes_connection(self):
"""Check that the client closes unused HTTP connections"""
d = self.cl.get_json("testserv:8008", "foo/bar")
d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
self.pump()

View file

@ -16,8 +16,6 @@ import logging
from mock import Mock
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory
from synapse.rest.admin import register_servlets_for_client_rest_resource
@ -25,6 +23,7 @@ from synapse.rest.client.v1 import login, room
from synapse.types import UserID
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils import make_awaitable
logger = logging.getLogger(__name__)
@ -46,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new event.
"""
mock_client = Mock(spec=["put_json"])
mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
@ -74,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new events.
"""
mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@ -86,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
)
mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@ -137,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new typing EDUs.
"""
mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
@ -149,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
)
mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
self.make_worker_hs(
"synapse.app.federation_sender",
{

View file

@ -178,7 +178,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.fetches = []
def get_file(destination, path, output_stream, args=None, max_size=None):
async def get_file(destination, path, output_stream, args=None, max_size=None):
"""
Returns tuple[int,dict,str,int] of file length, response headers,
absolute URI, and response code.
@ -192,7 +192,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
d = Deferred()
d.addCallback(write_to)
self.fetches.append((d, destination, path, args))
return make_deferred_yieldable(d)
return await make_deferred_yieldable(d)
client = Mock()
client.get_file = get_file

View file

@ -53,7 +53,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect an outgoing GET request for the given key
"""
def get_json(destination, path, ignore_backoff=False, **kwargs):
async def get_json(destination, path, ignore_backoff=False, **kwargs):
self.assertTrue(ignore_backoff)
self.assertEqual(destination, server_name)
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
@ -177,7 +177,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
# wire up outbound POST /key/v2/query requests from hs2 so that they
# will be forwarded to hs1
def post_json(destination, path, data):
async def post_json(destination, path, data):
self.assertEqual(destination, self.hs.hostname)
self.assertEqual(
path, "/_matrix/key/v2/query",

View file

@ -95,7 +95,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
prev_events that said event references.
"""
def post_json(destination, path, data, headers=None, timeout=0):
async def post_json(destination, path, data, headers=None, timeout=0):
# If it asks us for new missing events, give them NOTHING
if path.startswith("/_matrix/federation/v1/get_missing_events/"):
return {"events": []}