mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 10:03:54 +01:00
Add type annotations to SimpleHttpClient (#8372)
This commit is contained in:
parent
6fdf577593
commit
11c9e17738
4 changed files with 143 additions and 61 deletions
1
changelog.d/8372.misc
Normal file
1
changelog.d/8372.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add type annotations to `SimpleHttpClient`.
|
|
@ -178,7 +178,7 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||||
urllib.parse.quote(protocol),
|
urllib.parse.quote(protocol),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
info = await self.get_json(uri, {})
|
info = await self.get_json(uri)
|
||||||
|
|
||||||
if not _is_valid_3pe_metadata(info):
|
if not _is_valid_3pe_metadata(info):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
|
@ -17,6 +17,18 @@
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
BinaryIO,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import treq
|
import treq
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
@ -37,6 +49,7 @@ from twisted.web._newclient import ResponseDone
|
||||||
from twisted.web.client import Agent, HTTPConnectionPool, readBody
|
from twisted.web.client import Agent, HTTPConnectionPool, readBody
|
||||||
from twisted.web.http import PotentialDataLoss
|
from twisted.web.http import PotentialDataLoss
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
|
from twisted.web.iweb import IResponse
|
||||||
|
|
||||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||||
from synapse.http import (
|
from synapse.http import (
|
||||||
|
@ -57,6 +70,19 @@ incoming_responses_counter = Counter(
|
||||||
"synapse_http_client_responses", "", ["method", "code"]
|
"synapse_http_client_responses", "", ["method", "code"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# the type of the headers list, to be passed to the t.w.h.Headers.
|
||||||
|
# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so
|
||||||
|
# we simplify.
|
||||||
|
RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]]
|
||||||
|
|
||||||
|
# the value actually has to be a List, but List is invariant so we can't specify that
|
||||||
|
# the entries can either be str or bytes.
|
||||||
|
RawHeaderValue = Sequence[Union[str, bytes]]
|
||||||
|
|
||||||
|
# the type of the query params, to be passed into `urlencode`
|
||||||
|
QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
|
||||||
|
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
|
||||||
|
|
||||||
|
|
||||||
def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
|
def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
|
||||||
"""
|
"""
|
||||||
|
@ -285,13 +311,26 @@ class SimpleHttpClient:
|
||||||
ip_blacklist=self._ip_blacklist,
|
ip_blacklist=self._ip_blacklist,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def request(self, method, uri, data=None, headers=None):
|
async def request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
uri: str,
|
||||||
|
data: Optional[bytes] = None,
|
||||||
|
headers: Optional[Headers] = None,
|
||||||
|
) -> IResponse:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
method (str): HTTP method to use.
|
method: HTTP method to use.
|
||||||
uri (str): URI to query.
|
uri: URI to query.
|
||||||
data (bytes): Data to send in the request body, if applicable.
|
data: Data to send in the request body, if applicable.
|
||||||
headers (t.w.http_headers.Headers): Request headers.
|
headers: Request headers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response object, once the headers have been read.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RequestTimedOutError if the request times out before the headers are read
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# A small wrapper around self.agent.request() so we can easily attach
|
# A small wrapper around self.agent.request() so we can easily attach
|
||||||
# counters to it
|
# counters to it
|
||||||
|
@ -324,6 +363,8 @@ class SimpleHttpClient:
|
||||||
headers=headers,
|
headers=headers,
|
||||||
**self._extra_treq_args
|
**self._extra_treq_args
|
||||||
)
|
)
|
||||||
|
# we use our own timeout mechanism rather than treq's as a workaround
|
||||||
|
# for https://twistedmatrix.com/trac/ticket/9534.
|
||||||
request_deferred = timeout_deferred(
|
request_deferred = timeout_deferred(
|
||||||
request_deferred,
|
request_deferred,
|
||||||
60,
|
60,
|
||||||
|
@ -353,18 +394,26 @@ class SimpleHttpClient:
|
||||||
set_tag("error_reason", e.args[0])
|
set_tag("error_reason", e.args[0])
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def post_urlencoded_get_json(self, uri, args={}, headers=None):
|
async def post_urlencoded_get_json(
|
||||||
|
self,
|
||||||
|
uri: str,
|
||||||
|
args: Mapping[str, Union[str, List[str]]] = {},
|
||||||
|
headers: Optional[RawHeaders] = None,
|
||||||
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
uri (str):
|
uri: uri to query
|
||||||
args (dict[str, str|List[str]]): query params
|
args: parameters to be url-encoded in the body
|
||||||
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
|
headers: a map from header name to a list of values for that header
|
||||||
header name to a list of values for that header
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
object: parsed json
|
parsed json
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
RequestTimedOutError: if there is a timeout before the response headers
|
||||||
|
are received. Note there is currently no timeout on reading the response
|
||||||
|
body.
|
||||||
|
|
||||||
HttpResponseException: On a non-2xx HTTP response.
|
HttpResponseException: On a non-2xx HTTP response.
|
||||||
|
|
||||||
ValueError: if the response was not JSON
|
ValueError: if the response was not JSON
|
||||||
|
@ -398,19 +447,24 @@ class SimpleHttpClient:
|
||||||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||||
)
|
)
|
||||||
|
|
||||||
async def post_json_get_json(self, uri, post_json, headers=None):
|
async def post_json_get_json(
|
||||||
|
self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None
|
||||||
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
uri (str):
|
uri: URI to query.
|
||||||
post_json (object):
|
post_json: request body, to be encoded as json
|
||||||
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
|
headers: a map from header name to a list of values for that header
|
||||||
header name to a list of values for that header
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
object: parsed json
|
parsed json
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
RequestTimedOutError: if there is a timeout before the response headers
|
||||||
|
are received. Note there is currently no timeout on reading the response
|
||||||
|
body.
|
||||||
|
|
||||||
HttpResponseException: On a non-2xx HTTP response.
|
HttpResponseException: On a non-2xx HTTP response.
|
||||||
|
|
||||||
ValueError: if the response was not JSON
|
ValueError: if the response was not JSON
|
||||||
|
@ -440,21 +494,22 @@ class SimpleHttpClient:
|
||||||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_json(self, uri, args={}, headers=None):
|
async def get_json(
|
||||||
""" Gets some json from the given URI.
|
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Gets some json from the given URI.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
uri (str): The URI to request, not including query parameters
|
uri: The URI to request, not including query parameters
|
||||||
args (dict): A dictionary used to create query strings, defaults to
|
args: A dictionary used to create query strings
|
||||||
None.
|
headers: a map from header name to a list of values for that header
|
||||||
**Note**: The value of each key is assumed to be an iterable
|
|
||||||
and *not* a string.
|
|
||||||
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
|
|
||||||
header name to a list of values for that header
|
|
||||||
Returns:
|
Returns:
|
||||||
Succeeds when we get *any* 2xx HTTP response, with the
|
Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
|
||||||
HTTP body as JSON.
|
|
||||||
Raises:
|
Raises:
|
||||||
|
RequestTimedOutError: if there is a timeout before the response headers
|
||||||
|
are received. Note there is currently no timeout on reading the response
|
||||||
|
body.
|
||||||
|
|
||||||
HttpResponseException On a non-2xx HTTP response.
|
HttpResponseException On a non-2xx HTTP response.
|
||||||
|
|
||||||
ValueError: if the response was not JSON
|
ValueError: if the response was not JSON
|
||||||
|
@ -466,22 +521,27 @@ class SimpleHttpClient:
|
||||||
body = await self.get_raw(uri, args, headers=headers)
|
body = await self.get_raw(uri, args, headers=headers)
|
||||||
return json_decoder.decode(body.decode("utf-8"))
|
return json_decoder.decode(body.decode("utf-8"))
|
||||||
|
|
||||||
async def put_json(self, uri, json_body, args={}, headers=None):
|
async def put_json(
|
||||||
""" Puts some json to the given URI.
|
self,
|
||||||
|
uri: str,
|
||||||
|
json_body: Any,
|
||||||
|
args: QueryParams = {},
|
||||||
|
headers: RawHeaders = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Puts some json to the given URI.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
uri (str): The URI to request, not including query parameters
|
uri: The URI to request, not including query parameters
|
||||||
json_body (dict): The JSON to put in the HTTP body,
|
json_body: The JSON to put in the HTTP body,
|
||||||
args (dict): A dictionary used to create query strings, defaults to
|
args: A dictionary used to create query strings
|
||||||
None.
|
headers: a map from header name to a list of values for that header
|
||||||
**Note**: The value of each key is assumed to be an iterable
|
|
||||||
and *not* a string.
|
|
||||||
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
|
|
||||||
header name to a list of values for that header
|
|
||||||
Returns:
|
Returns:
|
||||||
Succeeds when we get *any* 2xx HTTP response, with the
|
Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
|
||||||
HTTP body as JSON.
|
|
||||||
Raises:
|
Raises:
|
||||||
|
RequestTimedOutError: if there is a timeout before the response headers
|
||||||
|
are received. Note there is currently no timeout on reading the response
|
||||||
|
body.
|
||||||
|
|
||||||
HttpResponseException On a non-2xx HTTP response.
|
HttpResponseException On a non-2xx HTTP response.
|
||||||
|
|
||||||
ValueError: if the response was not JSON
|
ValueError: if the response was not JSON
|
||||||
|
@ -513,21 +573,23 @@ class SimpleHttpClient:
|
||||||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_raw(self, uri, args={}, headers=None):
|
async def get_raw(
|
||||||
""" Gets raw text from the given URI.
|
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
|
||||||
|
) -> bytes:
|
||||||
|
"""Gets raw text from the given URI.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
uri (str): The URI to request, not including query parameters
|
uri: The URI to request, not including query parameters
|
||||||
args (dict): A dictionary used to create query strings, defaults to
|
args: A dictionary used to create query strings
|
||||||
None.
|
headers: a map from header name to a list of values for that header
|
||||||
**Note**: The value of each key is assumed to be an iterable
|
|
||||||
and *not* a string.
|
|
||||||
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
|
|
||||||
header name to a list of values for that header
|
|
||||||
Returns:
|
Returns:
|
||||||
Succeeds when we get *any* 2xx HTTP response, with the
|
Succeeds when we get a 2xx HTTP response, with the
|
||||||
HTTP body as bytes.
|
HTTP body as bytes.
|
||||||
Raises:
|
Raises:
|
||||||
|
RequestTimedOutError: if there is a timeout before the response headers
|
||||||
|
are received. Note there is currently no timeout on reading the response
|
||||||
|
body.
|
||||||
|
|
||||||
HttpResponseException on a non-2xx HTTP response.
|
HttpResponseException on a non-2xx HTTP response.
|
||||||
"""
|
"""
|
||||||
if len(args):
|
if len(args):
|
||||||
|
@ -552,16 +614,29 @@ class SimpleHttpClient:
|
||||||
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
|
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
|
||||||
# The two should be factored out.
|
# The two should be factored out.
|
||||||
|
|
||||||
async def get_file(self, url, output_stream, max_size=None, headers=None):
|
async def get_file(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
output_stream: BinaryIO,
|
||||||
|
max_size: Optional[int] = None,
|
||||||
|
headers: Optional[RawHeaders] = None,
|
||||||
|
) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
|
||||||
"""GETs a file from a given URL
|
"""GETs a file from a given URL
|
||||||
Args:
|
Args:
|
||||||
url (str): The URL to GET
|
url: The URL to GET
|
||||||
output_stream (file): File to write the response body to.
|
output_stream: File to write the response body to.
|
||||||
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
|
headers: A map from header name to a list of values for that header
|
||||||
header name to a list of values for that header
|
|
||||||
Returns:
|
Returns:
|
||||||
A (int,dict,string,int) tuple of the file length, dict of the response
|
A tuple of the file length, dict of the response
|
||||||
headers, absolute URI of the response and HTTP response code.
|
headers, absolute URI of the response and HTTP response code.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RequestTimedOutError: if there is a timeout before the response headers
|
||||||
|
are received. Note there is currently no timeout on reading the response
|
||||||
|
body.
|
||||||
|
|
||||||
|
SynapseError: if the response is not a 2xx, the remote file is too large, or
|
||||||
|
another exception happens during the download.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
actual_headers = {b"User-Agent": [self.user_agent]}
|
actual_headers = {b"User-Agent": [self.user_agent]}
|
||||||
|
|
|
@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||||
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
|
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
|
||||||
raise OEmbedError() from e
|
raise OEmbedError() from e
|
||||||
|
|
||||||
async def _download_url(self, url, user):
|
async def _download_url(self, url: str, user):
|
||||||
# TODO: we should probably honour robots.txt... except in practice
|
# TODO: we should probably honour robots.txt... except in practice
|
||||||
# we're most likely being explicitly triggered by a human rather than a
|
# we're most likely being explicitly triggered by a human rather than a
|
||||||
# bot, so are we really a robot?
|
# bot, so are we really a robot?
|
||||||
|
@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||||
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
|
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
|
||||||
|
|
||||||
# If this URL can be accessed via oEmbed, use that instead.
|
# If this URL can be accessed via oEmbed, use that instead.
|
||||||
url_to_download = url
|
url_to_download = url # type: Optional[str]
|
||||||
oembed_url = self._get_oembed_url(url)
|
oembed_url = self._get_oembed_url(url)
|
||||||
if oembed_url:
|
if oembed_url:
|
||||||
# The result might be a new URL to download, or it might be HTML content.
|
# The result might be a new URL to download, or it might be HTML content.
|
||||||
|
@ -520,9 +520,15 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||||
# FIXME: we should calculate a proper expiration based on the
|
# FIXME: we should calculate a proper expiration based on the
|
||||||
# Cache-Control and Expire headers. But for now, assume 1 hour.
|
# Cache-Control and Expire headers. But for now, assume 1 hour.
|
||||||
expires = ONE_HOUR
|
expires = ONE_HOUR
|
||||||
etag = headers["ETag"][0] if "ETag" in headers else None
|
etag = (
|
||||||
|
headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
html_bytes = oembed_result.html.encode("utf-8") # type: ignore
|
# we can only get here if we did an oembed request and have an oembed_result.html
|
||||||
|
assert oembed_result.html is not None
|
||||||
|
assert oembed_url is not None
|
||||||
|
|
||||||
|
html_bytes = oembed_result.html.encode("utf-8")
|
||||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||||
f.write(html_bytes)
|
f.write(html_bytes)
|
||||||
await finish()
|
await finish()
|
||||||
|
|
Loading…
Reference in a new issue