Add a custom_headers param to make_request (#8760)

Some tests want to set some custom HTTP request headers, so provide a way to do
that before calling requestReceived().
This commit is contained in:
Richard van der Hoff 2020-11-16 14:45:22 +00:00 committed by GitHub
parent f1de4bb58b
commit ebc405446e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 24 additions and 11 deletions

1
changelog.d/8760.misc Normal file
View file

@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.

View file

@ -296,10 +296,12 @@ class RestHelper:
image_length = len(image_data) image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,) path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok self.hs.get_reactor(),
) "POST",
request.requestHeaders.addRawHeader( path,
b"Content-Length", str(image_length).encode("UTF-8") content=image_data,
access_token=tok,
custom_headers=[(b"Content-Length", str(image_length))],
) )
request.render(resource) request.render(resource)
self.hs.get_reactor().pump([100]) self.hs.get_reactor().pump([100])

View file

@ -2,7 +2,7 @@ import json
import logging import logging
from collections import deque from collections import deque
from io import SEEK_END, BytesIO from io import SEEK_END, BytesIO
from typing import Callable from typing import Callable, Iterable, Optional, Tuple, Union
import attr import attr
from typing_extensions import Deque from typing_extensions import Deque
@ -139,6 +139,9 @@ def make_request(
shorthand=True, shorthand=True,
federation_auth_origin=None, federation_auth_origin=None,
content_is_form=False, content_is_form=False,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
): ):
""" """
Make a web request using the given method and path, feed it the Make a web request using the given method and path, feed it the
@ -157,6 +160,8 @@ def make_request(
content_is_form: Whether the content is URL encoded form data. Adds the content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header. 'Content-Type': 'application/x-www-form-urlencoded' header.
custom_headers: (name, value) pairs to add as request headers
Returns: Returns:
Tuple[synapse.http.site.SynapseRequest, channel] Tuple[synapse.http.site.SynapseRequest, channel]
""" """
@ -211,6 +216,10 @@ def make_request(
# Assume the body is JSON # Assume the body is JSON
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
if custom_headers:
for k, v in custom_headers:
req.requestHeaders.addRawHeader(k, v)
req.requestReceived(method, path, b"1.1") req.requestReceived(method, path, b"1.1")
return req, channel return req, channel

View file

@ -21,6 +21,7 @@ from synapse.http.site import XForwardedForRequest
from synapse.rest.client.v1 import login from synapse.rest.client.v1 import login
from tests import unittest from tests import unittest
from tests.server import make_request
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
@ -408,17 +409,17 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
# Advance to a known time # Advance to a known time
self.reactor.advance(123456 - self.reactor.seconds()) self.reactor.advance(123456 - self.reactor.seconds())
request, channel = self.make_request( headers1 = {b"User-Agent": b"Mozzila pizza"}
headers1.update(headers)
request, channel = make_request(
self.reactor,
"GET", "GET",
"/_matrix/client/r0/admin/users/" + self.user_id, "/_matrix/client/r0/admin/users/" + self.user_id,
access_token=access_token, access_token=access_token,
custom_headers=headers1.items(),
**make_request_args, **make_request_args,
) )
request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
# Add the optional headers
for h, v in headers.items():
request.requestHeaders.addRawHeader(h, v)
self.render(request) self.render(request)
# Advance so the save loop occurs # Advance so the save loop occurs