forked from MirrorHub/synapse
Remove some boilerplate in tests (#4156)
This commit is contained in:
parent
0f5e51f726
commit
e62f7f17b3
11 changed files with 163 additions and 217 deletions
1
changelog.d/4156.misc
Normal file
1
changelog.d/4156.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
HTTP tests have been refactored to contain less boilerplate.
|
|
@ -19,24 +19,17 @@ import json
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
from synapse.http.server import JsonResource
|
|
||||||
from synapse.rest.client.v1.admin import register_servlets
|
from synapse.rest.client.v1.admin import register_servlets
|
||||||
from synapse.util import Clock
|
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import (
|
|
||||||
ThreadedMemoryReactorClock,
|
|
||||||
make_request,
|
|
||||||
render,
|
|
||||||
setup_test_homeserver,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UserRegisterTestCase(unittest.TestCase):
|
class UserRegisterTestCase(unittest.HomeserverTestCase):
|
||||||
def setUp(self):
|
|
||||||
|
servlets = [register_servlets]
|
||||||
|
|
||||||
|
def make_homeserver(self, reactor, clock):
|
||||||
|
|
||||||
self.clock = ThreadedMemoryReactorClock()
|
|
||||||
self.hs_clock = Clock(self.clock)
|
|
||||||
self.url = "/_matrix/client/r0/admin/register"
|
self.url = "/_matrix/client/r0/admin/register"
|
||||||
|
|
||||||
self.registration_handler = Mock()
|
self.registration_handler = Mock()
|
||||||
|
@ -50,17 +43,14 @@ class UserRegisterTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.secrets = Mock()
|
self.secrets = Mock()
|
||||||
|
|
||||||
self.hs = setup_test_homeserver(
|
self.hs = self.setup_test_homeserver()
|
||||||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
|
|
||||||
)
|
|
||||||
|
|
||||||
self.hs.config.registration_shared_secret = u"shared"
|
self.hs.config.registration_shared_secret = u"shared"
|
||||||
|
|
||||||
self.hs.get_media_repository = Mock()
|
self.hs.get_media_repository = Mock()
|
||||||
self.hs.get_deactivate_account_handler = Mock()
|
self.hs.get_deactivate_account_handler = Mock()
|
||||||
|
|
||||||
self.resource = JsonResource(self.hs)
|
return self.hs
|
||||||
register_servlets(self.hs, self.resource)
|
|
||||||
|
|
||||||
def test_disabled(self):
|
def test_disabled(self):
|
||||||
"""
|
"""
|
||||||
|
@ -69,8 +59,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
self.hs.config.registration_shared_secret = None
|
self.hs.config.registration_shared_secret = None
|
||||||
|
|
||||||
request, channel = make_request("POST", self.url, b'{}')
|
request, channel = self.make_request("POST", self.url, b'{}')
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -87,8 +77,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.hs.get_secrets = Mock(return_value=secrets)
|
self.hs.get_secrets = Mock(return_value=secrets)
|
||||||
|
|
||||||
request, channel = make_request("GET", self.url)
|
request, channel = self.make_request("GET", self.url)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(channel.json_body, {"nonce": "abcd"})
|
self.assertEqual(channel.json_body, {"nonce": "abcd"})
|
||||||
|
|
||||||
|
@ -97,25 +87,25 @@ class UserRegisterTestCase(unittest.TestCase):
|
||||||
Calling GET on the endpoint will return a randomised nonce, which will
|
Calling GET on the endpoint will return a randomised nonce, which will
|
||||||
only last for SALT_TIMEOUT (60s).
|
only last for SALT_TIMEOUT (60s).
|
||||||
"""
|
"""
|
||||||
request, channel = make_request("GET", self.url)
|
request, channel = self.make_request("GET", self.url)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
nonce = channel.json_body["nonce"]
|
nonce = channel.json_body["nonce"]
|
||||||
|
|
||||||
# 59 seconds
|
# 59 seconds
|
||||||
self.clock.advance(59)
|
self.reactor.advance(59)
|
||||||
|
|
||||||
body = json.dumps({"nonce": nonce})
|
body = json.dumps({"nonce": nonce})
|
||||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual('username must be specified', channel.json_body["error"])
|
self.assertEqual('username must be specified', channel.json_body["error"])
|
||||||
|
|
||||||
# 61 seconds
|
# 61 seconds
|
||||||
self.clock.advance(2)
|
self.reactor.advance(2)
|
||||||
|
|
||||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual('unrecognised nonce', channel.json_body["error"])
|
self.assertEqual('unrecognised nonce', channel.json_body["error"])
|
||||||
|
@ -124,8 +114,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Only the provided nonce can be used, as it's checked in the MAC.
|
Only the provided nonce can be used, as it's checked in the MAC.
|
||||||
"""
|
"""
|
||||||
request, channel = make_request("GET", self.url)
|
request, channel = self.make_request("GET", self.url)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
nonce = channel.json_body["nonce"]
|
nonce = channel.json_body["nonce"]
|
||||||
|
|
||||||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
||||||
|
@ -141,8 +131,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
||||||
"mac": want_mac,
|
"mac": want_mac,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual("HMAC incorrect", channel.json_body["error"])
|
self.assertEqual("HMAC incorrect", channel.json_body["error"])
|
||||||
|
@ -152,8 +142,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
||||||
When the correct nonce is provided, and the right key is provided, the
|
When the correct nonce is provided, and the right key is provided, the
|
||||||
user is registered.
|
user is registered.
|
||||||
"""
|
"""
|
||||||
request, channel = make_request("GET", self.url)
|
request, channel = self.make_request("GET", self.url)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
nonce = channel.json_body["nonce"]
|
nonce = channel.json_body["nonce"]
|
||||||
|
|
||||||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
||||||
|
@ -169,8 +159,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
||||||
"mac": want_mac,
|
"mac": want_mac,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual("@bob:test", channel.json_body["user_id"])
|
self.assertEqual("@bob:test", channel.json_body["user_id"])
|
||||||
|
@ -179,8 +169,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
A valid unrecognised nonce.
|
A valid unrecognised nonce.
|
||||||
"""
|
"""
|
||||||
request, channel = make_request("GET", self.url)
|
request, channel = self.make_request("GET", self.url)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
nonce = channel.json_body["nonce"]
|
nonce = channel.json_body["nonce"]
|
||||||
|
|
||||||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
||||||
|
@ -196,15 +186,15 @@ class UserRegisterTestCase(unittest.TestCase):
|
||||||
"mac": want_mac,
|
"mac": want_mac,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual("@bob:test", channel.json_body["user_id"])
|
self.assertEqual("@bob:test", channel.json_body["user_id"])
|
||||||
|
|
||||||
# Now, try and reuse it
|
# Now, try and reuse it
|
||||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual('unrecognised nonce', channel.json_body["error"])
|
self.assertEqual('unrecognised nonce', channel.json_body["error"])
|
||||||
|
@ -217,8 +207,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def nonce():
|
def nonce():
|
||||||
request, channel = make_request("GET", self.url)
|
request, channel = self.make_request("GET", self.url)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
return channel.json_body["nonce"]
|
return channel.json_body["nonce"]
|
||||||
|
|
||||||
#
|
#
|
||||||
|
@ -227,8 +217,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# Must be present
|
# Must be present
|
||||||
body = json.dumps({})
|
body = json.dumps({})
|
||||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual('nonce must be specified', channel.json_body["error"])
|
self.assertEqual('nonce must be specified', channel.json_body["error"])
|
||||||
|
@ -239,32 +229,32 @@ class UserRegisterTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# Must be present
|
# Must be present
|
||||||
body = json.dumps({"nonce": nonce()})
|
body = json.dumps({"nonce": nonce()})
|
||||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual('username must be specified', channel.json_body["error"])
|
self.assertEqual('username must be specified', channel.json_body["error"])
|
||||||
|
|
||||||
# Must be a string
|
# Must be a string
|
||||||
body = json.dumps({"nonce": nonce(), "username": 1234})
|
body = json.dumps({"nonce": nonce(), "username": 1234})
|
||||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual('Invalid username', channel.json_body["error"])
|
self.assertEqual('Invalid username', channel.json_body["error"])
|
||||||
|
|
||||||
# Must not have null bytes
|
# Must not have null bytes
|
||||||
body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"})
|
body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"})
|
||||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual('Invalid username', channel.json_body["error"])
|
self.assertEqual('Invalid username', channel.json_body["error"])
|
||||||
|
|
||||||
# Must not have null bytes
|
# Must not have null bytes
|
||||||
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
|
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
|
||||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual('Invalid username', channel.json_body["error"])
|
self.assertEqual('Invalid username', channel.json_body["error"])
|
||||||
|
@ -275,16 +265,16 @@ class UserRegisterTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# Must be present
|
# Must be present
|
||||||
body = json.dumps({"nonce": nonce(), "username": "a"})
|
body = json.dumps({"nonce": nonce(), "username": "a"})
|
||||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual('password must be specified', channel.json_body["error"])
|
self.assertEqual('password must be specified', channel.json_body["error"])
|
||||||
|
|
||||||
# Must be a string
|
# Must be a string
|
||||||
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
|
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
|
||||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual('Invalid password', channel.json_body["error"])
|
self.assertEqual('Invalid password', channel.json_body["error"])
|
||||||
|
@ -293,16 +283,16 @@ class UserRegisterTestCase(unittest.TestCase):
|
||||||
body = json.dumps(
|
body = json.dumps(
|
||||||
{"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}
|
{"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}
|
||||||
)
|
)
|
||||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual('Invalid password', channel.json_body["error"])
|
self.assertEqual('Invalid password', channel.json_body["error"])
|
||||||
|
|
||||||
# Super long
|
# Super long
|
||||||
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
|
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
|
||||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual('Invalid password', channel.json_body["error"])
|
self.assertEqual('Invalid password', channel.json_body["error"])
|
||||||
|
|
|
@ -45,11 +45,11 @@ class CreateUserServletTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
handlers = Mock(registration_handler=self.registration_handler)
|
handlers = Mock(registration_handler=self.registration_handler)
|
||||||
self.clock = MemoryReactorClock()
|
self.reactor = MemoryReactorClock()
|
||||||
self.hs_clock = Clock(self.clock)
|
self.hs_clock = Clock(self.reactor)
|
||||||
|
|
||||||
self.hs = self.hs = setup_test_homeserver(
|
self.hs = self.hs = setup_test_homeserver(
|
||||||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
|
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
|
||||||
)
|
)
|
||||||
self.hs.get_datastore = Mock(return_value=self.datastore)
|
self.hs.get_datastore = Mock(return_value=self.datastore)
|
||||||
self.hs.get_handlers = Mock(return_value=handlers)
|
self.hs.get_handlers = Mock(return_value=handlers)
|
||||||
|
@ -76,8 +76,8 @@ class CreateUserServletTestCase(unittest.TestCase):
|
||||||
return_value=(user_id, token)
|
return_value=(user_id, token)
|
||||||
)
|
)
|
||||||
|
|
||||||
request, channel = make_request(b"POST", url, request_data)
|
request, channel = make_request(self.reactor, b"POST", url, request_data)
|
||||||
render(request, res, self.clock)
|
render(request, res, self.reactor)
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"200")
|
self.assertEquals(channel.result["code"], b"200")
|
||||||
|
|
||||||
|
|
|
@ -169,7 +169,7 @@ class RestHelper(object):
|
||||||
path = path + "?access_token=%s" % tok
|
path = path + "?access_token=%s" % tok
|
||||||
|
|
||||||
request, channel = make_request(
|
request, channel = make_request(
|
||||||
"POST", path, json.dumps(content).encode('utf8')
|
self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8')
|
||||||
)
|
)
|
||||||
render(request, self.resource, self.hs.get_reactor())
|
render(request, self.resource, self.hs.get_reactor())
|
||||||
|
|
||||||
|
@ -217,7 +217,9 @@ class RestHelper(object):
|
||||||
|
|
||||||
data = {"membership": membership}
|
data = {"membership": membership}
|
||||||
|
|
||||||
request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))
|
request, channel = make_request(
|
||||||
|
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8')
|
||||||
|
)
|
||||||
|
|
||||||
render(request, self.resource, self.hs.get_reactor())
|
render(request, self.resource, self.hs.get_reactor())
|
||||||
|
|
||||||
|
@ -228,18 +230,6 @@ class RestHelper(object):
|
||||||
|
|
||||||
self.auth_user_id = temp_id
|
self.auth_user_id = temp_id
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def register(self, user_id):
|
|
||||||
(code, response) = yield self.mock_resource.trigger(
|
|
||||||
"POST",
|
|
||||||
"/_matrix/client/r0/register",
|
|
||||||
json.dumps(
|
|
||||||
{"user": user_id, "password": "test", "type": "m.login.password"}
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.assertEquals(200, code)
|
|
||||||
defer.returnValue(response)
|
|
||||||
|
|
||||||
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
|
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
|
||||||
if txn_id is None:
|
if txn_id is None:
|
||||||
txn_id = "m%s" % (str(time.time()))
|
txn_id = "m%s" % (str(time.time()))
|
||||||
|
@ -251,7 +241,9 @@ class RestHelper(object):
|
||||||
if tok:
|
if tok:
|
||||||
path = path + "?access_token=%s" % tok
|
path = path + "?access_token=%s" % tok
|
||||||
|
|
||||||
request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
|
request, channel = make_request(
|
||||||
|
self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8')
|
||||||
|
)
|
||||||
render(request, self.resource, self.hs.get_reactor())
|
render(request, self.resource, self.hs.get_reactor())
|
||||||
|
|
||||||
assert int(channel.result["code"]) == expect_code, (
|
assert int(channel.result["code"]) == expect_code, (
|
||||||
|
|
|
@ -13,84 +13,47 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import synapse.types
|
|
||||||
from synapse.api.errors import Codes
|
from synapse.api.errors import Codes
|
||||||
from synapse.http.server import JsonResource
|
|
||||||
from synapse.rest.client.v2_alpha import filter
|
from synapse.rest.client.v2_alpha import filter
|
||||||
from synapse.types import UserID
|
|
||||||
from synapse.util import Clock
|
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import (
|
|
||||||
ThreadedMemoryReactorClock as MemoryReactorClock,
|
|
||||||
make_request,
|
|
||||||
render,
|
|
||||||
setup_test_homeserver,
|
|
||||||
)
|
|
||||||
|
|
||||||
PATH_PREFIX = "/_matrix/client/v2_alpha"
|
PATH_PREFIX = "/_matrix/client/v2_alpha"
|
||||||
|
|
||||||
|
|
||||||
class FilterTestCase(unittest.TestCase):
|
class FilterTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
USER_ID = "@apple:test"
|
user_id = "@apple:test"
|
||||||
|
hijack_auth = True
|
||||||
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
|
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
|
||||||
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
|
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
|
||||||
TO_REGISTER = [filter]
|
servlets = [filter.register_servlets]
|
||||||
|
|
||||||
def setUp(self):
|
def prepare(self, reactor, clock, hs):
|
||||||
self.clock = MemoryReactorClock()
|
self.filtering = hs.get_filtering()
|
||||||
self.hs_clock = Clock(self.clock)
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
self.hs = setup_test_homeserver(
|
|
||||||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
|
|
||||||
)
|
|
||||||
|
|
||||||
self.auth = self.hs.get_auth()
|
|
||||||
|
|
||||||
def get_user_by_access_token(token=None, allow_guest=False):
|
|
||||||
return {
|
|
||||||
"user": UserID.from_string(self.USER_ID),
|
|
||||||
"token_id": 1,
|
|
||||||
"is_guest": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_user_by_req(request, allow_guest=False, rights="access"):
|
|
||||||
return synapse.types.create_requester(
|
|
||||||
UserID.from_string(self.USER_ID), 1, False, None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.auth.get_user_by_access_token = get_user_by_access_token
|
|
||||||
self.auth.get_user_by_req = get_user_by_req
|
|
||||||
|
|
||||||
self.store = self.hs.get_datastore()
|
|
||||||
self.filtering = self.hs.get_filtering()
|
|
||||||
self.resource = JsonResource(self.hs)
|
|
||||||
|
|
||||||
for r in self.TO_REGISTER:
|
|
||||||
r.register_servlets(self.hs, self.resource)
|
|
||||||
|
|
||||||
def test_add_filter(self):
|
def test_add_filter(self):
|
||||||
request, channel = make_request(
|
request, channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
|
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
|
||||||
self.EXAMPLE_FILTER_JSON,
|
self.EXAMPLE_FILTER_JSON,
|
||||||
)
|
)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"200")
|
self.assertEqual(channel.result["code"], b"200")
|
||||||
self.assertEqual(channel.json_body, {"filter_id": "0"})
|
self.assertEqual(channel.json_body, {"filter_id": "0"})
|
||||||
filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
|
filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
|
||||||
self.clock.advance(0)
|
self.pump()
|
||||||
self.assertEquals(filter.result, self.EXAMPLE_FILTER)
|
self.assertEquals(filter.result, self.EXAMPLE_FILTER)
|
||||||
|
|
||||||
def test_add_filter_for_other_user(self):
|
def test_add_filter_for_other_user(self):
|
||||||
request, channel = make_request(
|
request, channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
|
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
|
||||||
self.EXAMPLE_FILTER_JSON,
|
self.EXAMPLE_FILTER_JSON,
|
||||||
)
|
)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"403")
|
self.assertEqual(channel.result["code"], b"403")
|
||||||
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
|
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||||
|
@ -98,12 +61,12 @@ class FilterTestCase(unittest.TestCase):
|
||||||
def test_add_filter_non_local_user(self):
|
def test_add_filter_non_local_user(self):
|
||||||
_is_mine = self.hs.is_mine
|
_is_mine = self.hs.is_mine
|
||||||
self.hs.is_mine = lambda target_user: False
|
self.hs.is_mine = lambda target_user: False
|
||||||
request, channel = make_request(
|
request, channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
|
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
|
||||||
self.EXAMPLE_FILTER_JSON,
|
self.EXAMPLE_FILTER_JSON,
|
||||||
)
|
)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.hs.is_mine = _is_mine
|
self.hs.is_mine = _is_mine
|
||||||
self.assertEqual(channel.result["code"], b"403")
|
self.assertEqual(channel.result["code"], b"403")
|
||||||
|
@ -113,21 +76,21 @@ class FilterTestCase(unittest.TestCase):
|
||||||
filter_id = self.filtering.add_user_filter(
|
filter_id = self.filtering.add_user_filter(
|
||||||
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
|
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
|
||||||
)
|
)
|
||||||
self.clock.advance(1)
|
self.reactor.advance(1)
|
||||||
filter_id = filter_id.result
|
filter_id = filter_id.result
|
||||||
request, channel = make_request(
|
request, channel = self.make_request(
|
||||||
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
|
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
|
||||||
)
|
)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"200")
|
self.assertEqual(channel.result["code"], b"200")
|
||||||
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
|
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
|
||||||
|
|
||||||
def test_get_filter_non_existant(self):
|
def test_get_filter_non_existant(self):
|
||||||
request, channel = make_request(
|
request, channel = self.make_request(
|
||||||
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
|
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
|
||||||
)
|
)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"400")
|
self.assertEqual(channel.result["code"], b"400")
|
||||||
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
|
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||||
|
@ -135,18 +98,18 @@ class FilterTestCase(unittest.TestCase):
|
||||||
# Currently invalid params do not have an appropriate errcode
|
# Currently invalid params do not have an appropriate errcode
|
||||||
# in errors.py
|
# in errors.py
|
||||||
def test_get_filter_invalid_id(self):
|
def test_get_filter_invalid_id(self):
|
||||||
request, channel = make_request(
|
request, channel = self.make_request(
|
||||||
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
|
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
|
||||||
)
|
)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"400")
|
self.assertEqual(channel.result["code"], b"400")
|
||||||
|
|
||||||
# No ID also returns an invalid_id error
|
# No ID also returns an invalid_id error
|
||||||
def test_get_filter_no_id(self):
|
def test_get_filter_no_id(self):
|
||||||
request, channel = make_request(
|
request, channel = self.make_request(
|
||||||
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
|
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
|
||||||
)
|
)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b"400")
|
self.assertEqual(channel.result["code"], b"400")
|
||||||
|
|
|
@ -3,22 +3,19 @@ import json
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
from twisted.python import failure
|
from twisted.python import failure
|
||||||
from twisted.test.proto_helpers import MemoryReactorClock
|
|
||||||
|
|
||||||
from synapse.api.errors import InteractiveAuthIncompleteError
|
from synapse.api.errors import InteractiveAuthIncompleteError
|
||||||
from synapse.http.server import JsonResource
|
|
||||||
from synapse.rest.client.v2_alpha.register import register_servlets
|
from synapse.rest.client.v2_alpha.register import register_servlets
|
||||||
from synapse.util import Clock
|
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import make_request, render, setup_test_homeserver
|
|
||||||
|
|
||||||
|
|
||||||
class RegisterRestServletTestCase(unittest.TestCase):
|
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
def setUp(self):
|
|
||||||
|
servlets = [register_servlets]
|
||||||
|
|
||||||
|
def make_homeserver(self, reactor, clock):
|
||||||
|
|
||||||
self.clock = MemoryReactorClock()
|
|
||||||
self.hs_clock = Clock(self.clock)
|
|
||||||
self.url = b"/_matrix/client/r0/register"
|
self.url = b"/_matrix/client/r0/register"
|
||||||
|
|
||||||
self.appservice = None
|
self.appservice = None
|
||||||
|
@ -46,9 +43,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
identity_handler=self.identity_handler,
|
identity_handler=self.identity_handler,
|
||||||
login_handler=self.login_handler,
|
login_handler=self.login_handler,
|
||||||
)
|
)
|
||||||
self.hs = setup_test_homeserver(
|
self.hs = self.setup_test_homeserver()
|
||||||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
|
|
||||||
)
|
|
||||||
self.hs.get_auth = Mock(return_value=self.auth)
|
self.hs.get_auth = Mock(return_value=self.auth)
|
||||||
self.hs.get_handlers = Mock(return_value=self.handlers)
|
self.hs.get_handlers = Mock(return_value=self.handlers)
|
||||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||||
|
@ -58,8 +53,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.hs.config.registrations_require_3pid = []
|
self.hs.config.registrations_require_3pid = []
|
||||||
self.hs.config.auto_join_rooms = []
|
self.hs.config.auto_join_rooms = []
|
||||||
|
|
||||||
self.resource = JsonResource(self.hs)
|
return self.hs
|
||||||
register_servlets(self.hs, self.resource)
|
|
||||||
|
|
||||||
def test_POST_appservice_registration_valid(self):
|
def test_POST_appservice_registration_valid(self):
|
||||||
user_id = "@kermit:muppet"
|
user_id = "@kermit:muppet"
|
||||||
|
@ -69,10 +63,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
|
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
|
||||||
request_data = json.dumps({"username": "kermit"})
|
request_data = json.dumps({"username": "kermit"})
|
||||||
|
|
||||||
request, channel = make_request(
|
request, channel = self.make_request(
|
||||||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
||||||
)
|
)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
det_data = {
|
det_data = {
|
||||||
|
@ -85,25 +79,25 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
def test_POST_appservice_registration_invalid(self):
|
def test_POST_appservice_registration_invalid(self):
|
||||||
self.appservice = None # no application service exists
|
self.appservice = None # no application service exists
|
||||||
request_data = json.dumps({"username": "kermit"})
|
request_data = json.dumps({"username": "kermit"})
|
||||||
request, channel = make_request(
|
request, channel = self.make_request(
|
||||||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
||||||
)
|
)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||||
|
|
||||||
def test_POST_bad_password(self):
|
def test_POST_bad_password(self):
|
||||||
request_data = json.dumps({"username": "kermit", "password": 666})
|
request_data = json.dumps({"username": "kermit", "password": 666})
|
||||||
request, channel = make_request(b"POST", self.url, request_data)
|
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"400", channel.result)
|
self.assertEquals(channel.result["code"], b"400", channel.result)
|
||||||
self.assertEquals(channel.json_body["error"], "Invalid password")
|
self.assertEquals(channel.json_body["error"], "Invalid password")
|
||||||
|
|
||||||
def test_POST_bad_username(self):
|
def test_POST_bad_username(self):
|
||||||
request_data = json.dumps({"username": 777, "password": "monkey"})
|
request_data = json.dumps({"username": 777, "password": "monkey"})
|
||||||
request, channel = make_request(b"POST", self.url, request_data)
|
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"400", channel.result)
|
self.assertEquals(channel.result["code"], b"400", channel.result)
|
||||||
self.assertEquals(channel.json_body["error"], "Invalid username")
|
self.assertEquals(channel.json_body["error"], "Invalid username")
|
||||||
|
@ -121,8 +115,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
|
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
|
||||||
self.device_handler.check_device_registered = Mock(return_value=device_id)
|
self.device_handler.check_device_registered = Mock(return_value=device_id)
|
||||||
|
|
||||||
request, channel = make_request(b"POST", self.url, request_data)
|
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
det_data = {
|
det_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
@ -143,8 +137,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
|
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
|
||||||
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
|
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
|
||||||
|
|
||||||
request, channel = make_request(b"POST", self.url, request_data)
|
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||||
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
|
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
|
||||||
|
@ -155,8 +149,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.hs.config.allow_guest_access = True
|
self.hs.config.allow_guest_access = True
|
||||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||||
|
|
||||||
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
det_data = {
|
det_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
@ -169,8 +163,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
def test_POST_disabled_guest_registration(self):
|
def test_POST_disabled_guest_registration(self):
|
||||||
self.hs.config.allow_guest_access = False
|
self.hs.config.allow_guest_access = False
|
||||||
|
|
||||||
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||||
render(request, self.resource, self.clock)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||||
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
|
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
|
||||||
|
|
|
@ -34,6 +34,7 @@ class FakeChannel(object):
|
||||||
wire).
|
wire).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_reactor = attr.ib()
|
||||||
result = attr.ib(default=attr.Factory(dict))
|
result = attr.ib(default=attr.Factory(dict))
|
||||||
_producer = None
|
_producer = None
|
||||||
|
|
||||||
|
@ -63,6 +64,15 @@ class FakeChannel(object):
|
||||||
|
|
||||||
def registerProducer(self, producer, streaming):
|
def registerProducer(self, producer, streaming):
|
||||||
self._producer = producer
|
self._producer = producer
|
||||||
|
self.producerStreaming = streaming
|
||||||
|
|
||||||
|
def _produce():
|
||||||
|
if self._producer:
|
||||||
|
self._producer.resumeProducing()
|
||||||
|
self._reactor.callLater(0.1, _produce)
|
||||||
|
|
||||||
|
if not streaming:
|
||||||
|
self._reactor.callLater(0.0, _produce)
|
||||||
|
|
||||||
def unregisterProducer(self):
|
def unregisterProducer(self):
|
||||||
if self._producer is None:
|
if self._producer is None:
|
||||||
|
@ -105,7 +115,13 @@ class FakeSite:
|
||||||
|
|
||||||
|
|
||||||
def make_request(
|
def make_request(
|
||||||
method, path, content=b"", access_token=None, request=SynapseRequest, shorthand=True
|
reactor,
|
||||||
|
method,
|
||||||
|
path,
|
||||||
|
content=b"",
|
||||||
|
access_token=None,
|
||||||
|
request=SynapseRequest,
|
||||||
|
shorthand=True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Make a web request using the given method and path, feed it the
|
Make a web request using the given method and path, feed it the
|
||||||
|
@ -138,7 +154,7 @@ def make_request(
|
||||||
content = content.encode('utf8')
|
content = content.encode('utf8')
|
||||||
|
|
||||||
site = FakeSite()
|
site = FakeSite()
|
||||||
channel = FakeChannel()
|
channel = FakeChannel(reactor)
|
||||||
|
|
||||||
req = request(site, channel)
|
req = request(site, channel)
|
||||||
req.process = lambda: b""
|
req.process = lambda: b""
|
||||||
|
|
|
@ -21,30 +21,20 @@ from mock import Mock, NonCallableMock
|
||||||
|
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||||
from synapse.http.server import JsonResource
|
|
||||||
from synapse.rest.client.v2_alpha import register, sync
|
from synapse.rest.client.v2_alpha import register, sync
|
||||||
from synapse.util import Clock
|
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import (
|
|
||||||
ThreadedMemoryReactorClock,
|
|
||||||
make_request,
|
|
||||||
render,
|
|
||||||
setup_test_homeserver,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestMauLimit(unittest.TestCase):
|
class TestMauLimit(unittest.HomeserverTestCase):
|
||||||
def setUp(self):
|
|
||||||
self.reactor = ThreadedMemoryReactorClock()
|
|
||||||
self.clock = Clock(self.reactor)
|
|
||||||
|
|
||||||
self.hs = setup_test_homeserver(
|
servlets = [register.register_servlets, sync.register_servlets]
|
||||||
self.addCleanup,
|
|
||||||
|
def make_homeserver(self, reactor, clock):
|
||||||
|
|
||||||
|
self.hs = self.setup_test_homeserver(
|
||||||
"red",
|
"red",
|
||||||
http_client=None,
|
http_client=None,
|
||||||
clock=self.clock,
|
|
||||||
reactor=self.reactor,
|
|
||||||
federation_client=Mock(),
|
federation_client=Mock(),
|
||||||
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
||||||
)
|
)
|
||||||
|
@ -63,10 +53,7 @@ class TestMauLimit(unittest.TestCase):
|
||||||
self.hs.config.server_notices_mxid_display_name = None
|
self.hs.config.server_notices_mxid_display_name = None
|
||||||
self.hs.config.server_notices_mxid_avatar_url = None
|
self.hs.config.server_notices_mxid_avatar_url = None
|
||||||
self.hs.config.server_notices_room_name = "Test Server Notice Room"
|
self.hs.config.server_notices_room_name = "Test Server Notice Room"
|
||||||
|
return self.hs
|
||||||
self.resource = JsonResource(self.hs)
|
|
||||||
register.register_servlets(self.hs, self.resource)
|
|
||||||
sync.register_servlets(self.hs, self.resource)
|
|
||||||
|
|
||||||
def test_simple_deny_mau(self):
|
def test_simple_deny_mau(self):
|
||||||
# Create and sync so that the MAU counts get updated
|
# Create and sync so that the MAU counts get updated
|
||||||
|
@ -193,8 +180,8 @@ class TestMauLimit(unittest.TestCase):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
request, channel = make_request("POST", "/register", request_data)
|
request, channel = self.make_request("POST", "/register", request_data)
|
||||||
render(request, self.resource, self.reactor)
|
self.render(request)
|
||||||
|
|
||||||
if channel.code != 200:
|
if channel.code != 200:
|
||||||
raise HttpResponseException(
|
raise HttpResponseException(
|
||||||
|
@ -206,10 +193,10 @@ class TestMauLimit(unittest.TestCase):
|
||||||
return access_token
|
return access_token
|
||||||
|
|
||||||
def do_sync_for_user(self, token):
|
def do_sync_for_user(self, token):
|
||||||
request, channel = make_request(
|
request, channel = self.make_request(
|
||||||
"GET", "/sync", access_token=token
|
"GET", "/sync", access_token=token
|
||||||
)
|
)
|
||||||
render(request, self.resource, self.reactor)
|
self.render(request)
|
||||||
|
|
||||||
if channel.code != 200:
|
if channel.code != 200:
|
||||||
raise HttpResponseException(
|
raise HttpResponseException(
|
||||||
|
|
|
@ -57,7 +57,9 @@ class JsonResourceTests(unittest.TestCase):
|
||||||
"GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback
|
"GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback
|
||||||
)
|
)
|
||||||
|
|
||||||
request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
|
request, channel = make_request(
|
||||||
|
self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
|
||||||
|
)
|
||||||
render(request, res, self.reactor)
|
render(request, res, self.reactor)
|
||||||
|
|
||||||
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
|
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
|
||||||
|
@ -75,7 +77,7 @@ class JsonResourceTests(unittest.TestCase):
|
||||||
res = JsonResource(self.homeserver)
|
res = JsonResource(self.homeserver)
|
||||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||||
|
|
||||||
request, channel = make_request(b"GET", b"/_matrix/foo")
|
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
|
||||||
render(request, res, self.reactor)
|
render(request, res, self.reactor)
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b'500')
|
self.assertEqual(channel.result["code"], b'500')
|
||||||
|
@ -98,7 +100,7 @@ class JsonResourceTests(unittest.TestCase):
|
||||||
res = JsonResource(self.homeserver)
|
res = JsonResource(self.homeserver)
|
||||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||||
|
|
||||||
request, channel = make_request(b"GET", b"/_matrix/foo")
|
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
|
||||||
render(request, res, self.reactor)
|
render(request, res, self.reactor)
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b'500')
|
self.assertEqual(channel.result["code"], b'500')
|
||||||
|
@ -115,7 +117,7 @@ class JsonResourceTests(unittest.TestCase):
|
||||||
res = JsonResource(self.homeserver)
|
res = JsonResource(self.homeserver)
|
||||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||||
|
|
||||||
request, channel = make_request(b"GET", b"/_matrix/foo")
|
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
|
||||||
render(request, res, self.reactor)
|
render(request, res, self.reactor)
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b'403')
|
self.assertEqual(channel.result["code"], b'403')
|
||||||
|
@ -136,7 +138,7 @@ class JsonResourceTests(unittest.TestCase):
|
||||||
res = JsonResource(self.homeserver)
|
res = JsonResource(self.homeserver)
|
||||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||||
|
|
||||||
request, channel = make_request(b"GET", b"/_matrix/foobar")
|
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar")
|
||||||
render(request, res, self.reactor)
|
render(request, res, self.reactor)
|
||||||
|
|
||||||
self.assertEqual(channel.result["code"], b'400')
|
self.assertEqual(channel.result["code"], b'400')
|
||||||
|
|
|
@ -23,7 +23,6 @@ from synapse.rest.client.v2_alpha.register import register_servlets
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import make_request
|
|
||||||
|
|
||||||
|
|
||||||
class TermsTestCase(unittest.HomeserverTestCase):
|
class TermsTestCase(unittest.HomeserverTestCase):
|
||||||
|
@ -92,7 +91,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.registration_handler.check_username = Mock(return_value=True)
|
self.registration_handler.check_username = Mock(return_value=True)
|
||||||
|
|
||||||
request, channel = make_request(b"POST", self.url, request_data)
|
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
|
||||||
# We don't bother checking that the response is correct - we'll leave that to
|
# We don't bother checking that the response is correct - we'll leave that to
|
||||||
|
@ -110,7 +109,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
request, channel = make_request(b"POST", self.url, request_data)
|
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
|
||||||
# We're interested in getting a response that looks like a successful
|
# We're interested in getting a response that looks like a successful
|
||||||
|
|
|
@ -189,11 +189,11 @@ class HomeserverTestCase(TestCase):
|
||||||
for servlet in self.servlets:
|
for servlet in self.servlets:
|
||||||
servlet(self.hs, self.resource)
|
servlet(self.hs, self.resource)
|
||||||
|
|
||||||
|
from tests.rest.client.v1.utils import RestHelper
|
||||||
|
|
||||||
|
self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
|
||||||
|
|
||||||
if hasattr(self, "user_id"):
|
if hasattr(self, "user_id"):
|
||||||
from tests.rest.client.v1.utils import RestHelper
|
|
||||||
|
|
||||||
self.helper = RestHelper(self.hs, self.resource, self.user_id)
|
|
||||||
|
|
||||||
if self.hijack_auth:
|
if self.hijack_auth:
|
||||||
|
|
||||||
def get_user_by_access_token(token=None, allow_guest=False):
|
def get_user_by_access_token(token=None, allow_guest=False):
|
||||||
|
@ -285,7 +285,9 @@ class HomeserverTestCase(TestCase):
|
||||||
if isinstance(content, dict):
|
if isinstance(content, dict):
|
||||||
content = json.dumps(content).encode('utf8')
|
content = json.dumps(content).encode('utf8')
|
||||||
|
|
||||||
return make_request(method, path, content, access_token, request, shorthand)
|
return make_request(
|
||||||
|
self.reactor, method, path, content, access_token, request, shorthand
|
||||||
|
)
|
||||||
|
|
||||||
def render(self, request):
|
def render(self, request):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in a new issue