0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-15 22:42:23 +01:00

pass a Site into make_request

This commit is contained in:
Richard van der Hoff 2020-11-13 22:39:09 +00:00
parent d3523e3e97
commit 9debe657a3
4 changed files with 68 additions and 20 deletions

View file

@ -27,7 +27,7 @@ from twisted.web.server import Site
from synapse.api.constants import Membership from synapse.api.constants import Membership
from tests.server import make_request, render from tests.server import FakeSite, make_request, render
@attr.s @attr.s
@ -53,7 +53,11 @@ class RestHelper:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8") self.hs.get_reactor(),
self.site,
"POST",
path,
json.dumps(content).encode("utf8"),
) )
render(request, self.site.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
@ -126,7 +130,11 @@ class RestHelper:
data.update(extra_data) data.update(extra_data)
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8") self.hs.get_reactor(),
self.site,
"PUT",
path,
json.dumps(data).encode("utf8"),
) )
render(request, self.site.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
@ -159,7 +167,11 @@ class RestHelper:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8") self.hs.get_reactor(),
self.site,
"PUT",
path,
json.dumps(content).encode("utf8"),
) )
render(request, self.site.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
@ -211,7 +223,9 @@ class RestHelper:
if body is not None: if body is not None:
content = json.dumps(body).encode("utf8") content = json.dumps(body).encode("utf8")
request, channel = make_request(self.hs.get_reactor(), method, path, content) request, channel = make_request(
self.hs.get_reactor(), self.site, method, path, content
)
render(request, self.site.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
@ -297,7 +311,12 @@ class RestHelper:
image_length = len(image_data) image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,) path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok self.hs.get_reactor(),
FakeSite(resource),
"POST",
path,
content=image_data,
access_token=tok,
) )
request.requestHeaders.addRawHeader( request.requestHeaders.addRawHeader(
b"Content-Length", str(image_length).encode("UTF-8") b"Content-Length", str(image_length).encode("UTF-8")

View file

@ -21,6 +21,7 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http import unquote from twisted.web.http import unquote
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.resource import IResource
from twisted.web.server import Site from twisted.web.server import Site
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -128,9 +129,21 @@ class FakeSite:
site_tag = "test" site_tag = "test"
access_logger = logging.getLogger("synapse.access.http.fake") access_logger = logging.getLogger("synapse.access.http.fake")
def __init__(self, resource: IResource):
"""
Args:
resource: the resource to be used for rendering all requests
"""
self._resource = resource
def getResourceFor(self, request):
return self._resource
def make_request( def make_request(
reactor, reactor,
site: Site,
method, method,
path, path,
content=b"", content=b"",
@ -145,6 +158,8 @@ def make_request(
content, and return the Request and the Channel underneath. content, and return the Request and the Channel underneath.
Args: Args:
site: The twisted Site to associate with the Channel
method (bytes/unicode): The HTTP request method ("verb"). method (bytes/unicode): The HTTP request method ("verb").
path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
escaped UTF-8 & spaces and such). escaped UTF-8 & spaces and such).
@ -181,7 +196,6 @@ def make_request(
if isinstance(content, str): if isinstance(content, str):
content = content.encode("utf8") content = content.encode("utf8")
site = FakeSite()
channel = FakeChannel(site, reactor) channel = FakeChannel(site, reactor)
req = request(channel) req = request(channel)

View file

@ -26,6 +26,7 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import ( from tests.server import (
FakeSite,
ThreadedMemoryReactorClock, ThreadedMemoryReactorClock,
make_request, make_request,
render, render,
@ -62,7 +63,7 @@ class JsonResourceTests(unittest.TestCase):
) )
request, channel = make_request( request, channel = make_request(
self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
) )
render(request, res, self.reactor) render(request, res, self.reactor)
@ -83,7 +84,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"500") self.assertEqual(channel.result["code"], b"500")
@ -108,7 +111,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"500") self.assertEqual(channel.result["code"], b"500")
@ -127,7 +132,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.result["code"], b"403")
@ -150,7 +157,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.result["code"], b"400")
@ -173,7 +182,9 @@ class JsonResourceTests(unittest.TestCase):
) )
# The path was registered as GET, but this is a HEAD request. # The path was registered as GET, but this is a HEAD request.
request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
@ -196,9 +207,6 @@ class OptionsResourceTests(unittest.TestCase):
def _make_request(self, method, path): def _make_request(self, method, path):
"""Create a request from the method/path and return a channel with the response.""" """Create a request from the method/path and return a channel with the response."""
request, channel = make_request(self.reactor, method, path, shorthand=False)
request.prepath = [] # This doesn't get set properly by make_request.
# Create a site and query for the resource. # Create a site and query for the resource.
site = SynapseSite( site = SynapseSite(
"test", "test",
@ -207,6 +215,12 @@ class OptionsResourceTests(unittest.TestCase):
self.resource, self.resource,
"1.0", "1.0",
) )
request, channel = make_request(
self.reactor, site, method, path, shorthand=False
)
request.prepath = [] # This doesn't get set properly by make_request.
request.site = site request.site = site
resource = site.getResourceFor(request) resource = site.getResourceFor(request)
@ -284,7 +298,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
@ -303,7 +317,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"301") self.assertEqual(channel.result["code"], b"301")
@ -325,7 +339,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"304") self.assertEqual(channel.result["code"], b"304")
@ -345,7 +359,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"HEAD", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")

View file

@ -434,6 +434,7 @@ class HomeserverTestCase(TestCase):
return make_request( return make_request(
self.reactor, self.reactor,
self.site,
method, method,
path, path,
content, content,