mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 04:33:47 +01:00
Run Black on the tests again (#5170)
This commit is contained in:
parent
d9a02d1201
commit
b36c82576e
54 changed files with 818 additions and 1158 deletions
1
changelog.d/5170.misc
Normal file
1
changelog.d/5170.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Run `black` on the tests directory.
|
|
@ -109,7 +109,6 @@ class FilteringTestCase(unittest.TestCase):
|
||||||
"event_format": "client",
|
"event_format": "client",
|
||||||
"event_fields": ["type", "content", "sender"],
|
"event_fields": ["type", "content", "sender"],
|
||||||
},
|
},
|
||||||
|
|
||||||
# a single backslash should be permitted (though it is debatable whether
|
# a single backslash should be permitted (though it is debatable whether
|
||||||
# it should be permitted before anything other than `.`, and what that
|
# it should be permitted before anything other than `.`, and what that
|
||||||
# actually means)
|
# actually means)
|
||||||
|
|
|
@ -10,19 +10,19 @@ class TestRatelimiter(unittest.TestCase):
|
||||||
key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
|
key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
|
||||||
)
|
)
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEquals(10., time_allowed)
|
self.assertEquals(10.0, time_allowed)
|
||||||
|
|
||||||
allowed, time_allowed = limiter.can_do_action(
|
allowed, time_allowed = limiter.can_do_action(
|
||||||
key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
|
key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
|
||||||
)
|
)
|
||||||
self.assertFalse(allowed)
|
self.assertFalse(allowed)
|
||||||
self.assertEquals(10., time_allowed)
|
self.assertEquals(10.0, time_allowed)
|
||||||
|
|
||||||
allowed, time_allowed = limiter.can_do_action(
|
allowed, time_allowed = limiter.can_do_action(
|
||||||
key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
|
key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
|
||||||
)
|
)
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEquals(20., time_allowed)
|
self.assertEquals(20.0, time_allowed)
|
||||||
|
|
||||||
def test_pruning(self):
|
def test_pruning(self):
|
||||||
limiter = Ratelimiter()
|
limiter = Ratelimiter()
|
||||||
|
|
|
@ -25,16 +25,18 @@ from tests.unittest import HomeserverTestCase
|
||||||
class FederationReaderOpenIDListenerTests(HomeserverTestCase):
|
class FederationReaderOpenIDListenerTests(HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
http_client=None, homeserverToUse=FederationReaderServer,
|
http_client=None, homeserverToUse=FederationReaderServer
|
||||||
)
|
)
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
@parameterized.expand([
|
@parameterized.expand(
|
||||||
|
[
|
||||||
(["federation"], "auth_fail"),
|
(["federation"], "auth_fail"),
|
||||||
([], "no_resource"),
|
([], "no_resource"),
|
||||||
(["openid", "federation"], "auth_fail"),
|
(["openid", "federation"], "auth_fail"),
|
||||||
(["openid"], "auth_fail"),
|
(["openid"], "auth_fail"),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def test_openid_listener(self, names, expectation):
|
def test_openid_listener(self, names, expectation):
|
||||||
"""
|
"""
|
||||||
Test different openid listener configurations.
|
Test different openid listener configurations.
|
||||||
|
@ -53,17 +55,14 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
|
||||||
# Grab the resource from the site that was told to listen
|
# Grab the resource from the site that was told to listen
|
||||||
site = self.reactor.tcpServers[0][1]
|
site = self.reactor.tcpServers[0][1]
|
||||||
try:
|
try:
|
||||||
self.resource = (
|
self.resource = site.resource.children[b"_matrix"].children[b"federation"]
|
||||||
site.resource.children[b"_matrix"].children[b"federation"]
|
|
||||||
)
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if expectation == "no_resource":
|
if expectation == "no_resource":
|
||||||
return
|
return
|
||||||
raise
|
raise
|
||||||
|
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"GET",
|
"GET", "/_matrix/federation/v1/openid/userinfo"
|
||||||
"/_matrix/federation/v1/openid/userinfo",
|
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
|
||||||
|
@ -74,16 +73,18 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
|
||||||
class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
|
class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
http_client=None, homeserverToUse=SynapseHomeServer,
|
http_client=None, homeserverToUse=SynapseHomeServer
|
||||||
)
|
)
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
@parameterized.expand([
|
@parameterized.expand(
|
||||||
|
[
|
||||||
(["federation"], "auth_fail"),
|
(["federation"], "auth_fail"),
|
||||||
([], "no_resource"),
|
([], "no_resource"),
|
||||||
(["openid", "federation"], "auth_fail"),
|
(["openid", "federation"], "auth_fail"),
|
||||||
(["openid"], "auth_fail"),
|
(["openid"], "auth_fail"),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
def test_openid_listener(self, names, expectation):
|
def test_openid_listener(self, names, expectation):
|
||||||
"""
|
"""
|
||||||
Test different openid listener configurations.
|
Test different openid listener configurations.
|
||||||
|
@ -102,17 +103,14 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
|
||||||
# Grab the resource from the site that was told to listen
|
# Grab the resource from the site that was told to listen
|
||||||
site = self.reactor.tcpServers[0][1]
|
site = self.reactor.tcpServers[0][1]
|
||||||
try:
|
try:
|
||||||
self.resource = (
|
self.resource = site.resource.children[b"_matrix"].children[b"federation"]
|
||||||
site.resource.children[b"_matrix"].children[b"federation"]
|
|
||||||
)
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if expectation == "no_resource":
|
if expectation == "no_resource":
|
||||||
return
|
return
|
||||||
raise
|
raise
|
||||||
|
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"GET",
|
"GET", "/_matrix/federation/v1/openid/userinfo"
|
||||||
"/_matrix/federation/v1/openid/userinfo",
|
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
|
||||||
|
|
|
@ -45,13 +45,7 @@ class ConfigGenerationTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
set(
|
set(["homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"]),
|
||||||
[
|
|
||||||
"homeserver.yaml",
|
|
||||||
"lemurs.win.log.config",
|
|
||||||
"lemurs.win.signing.key",
|
|
||||||
]
|
|
||||||
),
|
|
||||||
set(os.listdir(self.dir)),
|
set(os.listdir(self.dir)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,8 @@ from tests import unittest
|
||||||
|
|
||||||
class RoomDirectoryConfigTestCase(unittest.TestCase):
|
class RoomDirectoryConfigTestCase(unittest.TestCase):
|
||||||
def test_alias_creation_acl(self):
|
def test_alias_creation_acl(self):
|
||||||
config = yaml.safe_load("""
|
config = yaml.safe_load(
|
||||||
|
"""
|
||||||
alias_creation_rules:
|
alias_creation_rules:
|
||||||
- user_id: "*bob*"
|
- user_id: "*bob*"
|
||||||
alias: "*"
|
alias: "*"
|
||||||
|
@ -38,43 +39,49 @@ class RoomDirectoryConfigTestCase(unittest.TestCase):
|
||||||
action: "allow"
|
action: "allow"
|
||||||
|
|
||||||
room_list_publication_rules: []
|
room_list_publication_rules: []
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
rd_config = RoomDirectoryConfig()
|
rd_config = RoomDirectoryConfig()
|
||||||
rd_config.read_config(config)
|
rd_config.read_config(config)
|
||||||
|
|
||||||
self.assertFalse(rd_config.is_alias_creation_allowed(
|
self.assertFalse(
|
||||||
user_id="@bob:example.com",
|
rd_config.is_alias_creation_allowed(
|
||||||
room_id="!test",
|
user_id="@bob:example.com", room_id="!test", alias="#test:example.com"
|
||||||
alias="#test:example.com",
|
)
|
||||||
))
|
)
|
||||||
|
|
||||||
self.assertTrue(rd_config.is_alias_creation_allowed(
|
self.assertTrue(
|
||||||
|
rd_config.is_alias_creation_allowed(
|
||||||
user_id="@test:example.com",
|
user_id="@test:example.com",
|
||||||
room_id="!test",
|
room_id="!test",
|
||||||
alias="#unofficial_st:example.com",
|
alias="#unofficial_st:example.com",
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertTrue(rd_config.is_alias_creation_allowed(
|
self.assertTrue(
|
||||||
|
rd_config.is_alias_creation_allowed(
|
||||||
user_id="@foobar:example.com",
|
user_id="@foobar:example.com",
|
||||||
room_id="!test",
|
room_id="!test",
|
||||||
alias="#test:example.com",
|
alias="#test:example.com",
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertTrue(rd_config.is_alias_creation_allowed(
|
self.assertTrue(
|
||||||
user_id="@gah:example.com",
|
rd_config.is_alias_creation_allowed(
|
||||||
room_id="!test",
|
user_id="@gah:example.com", room_id="!test", alias="#goo:example.com"
|
||||||
alias="#goo:example.com",
|
)
|
||||||
))
|
)
|
||||||
|
|
||||||
self.assertFalse(rd_config.is_alias_creation_allowed(
|
self.assertFalse(
|
||||||
user_id="@test:example.com",
|
rd_config.is_alias_creation_allowed(
|
||||||
room_id="!test",
|
user_id="@test:example.com", room_id="!test", alias="#test:example.com"
|
||||||
alias="#test:example.com",
|
)
|
||||||
))
|
)
|
||||||
|
|
||||||
def test_room_publish_acl(self):
|
def test_room_publish_acl(self):
|
||||||
config = yaml.safe_load("""
|
config = yaml.safe_load(
|
||||||
|
"""
|
||||||
alias_creation_rules: []
|
alias_creation_rules: []
|
||||||
|
|
||||||
room_list_publication_rules:
|
room_list_publication_rules:
|
||||||
|
@ -92,55 +99,66 @@ class RoomDirectoryConfigTestCase(unittest.TestCase):
|
||||||
action: "allow"
|
action: "allow"
|
||||||
- room_id: "!test-deny"
|
- room_id: "!test-deny"
|
||||||
action: "deny"
|
action: "deny"
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
rd_config = RoomDirectoryConfig()
|
rd_config = RoomDirectoryConfig()
|
||||||
rd_config.read_config(config)
|
rd_config.read_config(config)
|
||||||
|
|
||||||
self.assertFalse(rd_config.is_publishing_room_allowed(
|
self.assertFalse(
|
||||||
|
rd_config.is_publishing_room_allowed(
|
||||||
user_id="@bob:example.com",
|
user_id="@bob:example.com",
|
||||||
room_id="!test",
|
room_id="!test",
|
||||||
aliases=["#test:example.com"],
|
aliases=["#test:example.com"],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertTrue(rd_config.is_publishing_room_allowed(
|
self.assertTrue(
|
||||||
|
rd_config.is_publishing_room_allowed(
|
||||||
user_id="@test:example.com",
|
user_id="@test:example.com",
|
||||||
room_id="!test",
|
room_id="!test",
|
||||||
aliases=["#unofficial_st:example.com"],
|
aliases=["#unofficial_st:example.com"],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertTrue(rd_config.is_publishing_room_allowed(
|
self.assertTrue(
|
||||||
user_id="@foobar:example.com",
|
rd_config.is_publishing_room_allowed(
|
||||||
room_id="!test",
|
user_id="@foobar:example.com", room_id="!test", aliases=[]
|
||||||
aliases=[],
|
)
|
||||||
))
|
)
|
||||||
|
|
||||||
self.assertTrue(rd_config.is_publishing_room_allowed(
|
self.assertTrue(
|
||||||
|
rd_config.is_publishing_room_allowed(
|
||||||
user_id="@gah:example.com",
|
user_id="@gah:example.com",
|
||||||
room_id="!test",
|
room_id="!test",
|
||||||
aliases=["#goo:example.com"],
|
aliases=["#goo:example.com"],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertFalse(rd_config.is_publishing_room_allowed(
|
self.assertFalse(
|
||||||
|
rd_config.is_publishing_room_allowed(
|
||||||
user_id="@test:example.com",
|
user_id="@test:example.com",
|
||||||
room_id="!test",
|
room_id="!test",
|
||||||
aliases=["#test:example.com"],
|
aliases=["#test:example.com"],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertTrue(rd_config.is_publishing_room_allowed(
|
self.assertTrue(
|
||||||
user_id="@foobar:example.com",
|
rd_config.is_publishing_room_allowed(
|
||||||
room_id="!test-deny",
|
user_id="@foobar:example.com", room_id="!test-deny", aliases=[]
|
||||||
aliases=[],
|
)
|
||||||
))
|
)
|
||||||
|
|
||||||
self.assertFalse(rd_config.is_publishing_room_allowed(
|
self.assertFalse(
|
||||||
user_id="@gah:example.com",
|
rd_config.is_publishing_room_allowed(
|
||||||
room_id="!test-deny",
|
user_id="@gah:example.com", room_id="!test-deny", aliases=[]
|
||||||
aliases=[],
|
)
|
||||||
))
|
)
|
||||||
|
|
||||||
self.assertTrue(rd_config.is_publishing_room_allowed(
|
self.assertTrue(
|
||||||
|
rd_config.is_publishing_room_allowed(
|
||||||
user_id="@test:example.com",
|
user_id="@test:example.com",
|
||||||
room_id="!test",
|
room_id="!test",
|
||||||
aliases=["#unofficial_st:example.com", "#blah:example.com"],
|
aliases=["#unofficial_st:example.com", "#blah:example.com"],
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
|
@ -19,7 +19,6 @@ from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class ServerConfigTestCase(unittest.TestCase):
|
class ServerConfigTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def test_is_threepid_reserved(self):
|
def test_is_threepid_reserved(self):
|
||||||
user1 = {'medium': 'email', 'address': 'user1@example.com'}
|
user1 = {'medium': 'email', 'address': 'user1@example.com'}
|
||||||
user2 = {'medium': 'email', 'address': 'user2@example.com'}
|
user2 = {'medium': 'email', 'address': 'user2@example.com'}
|
||||||
|
|
|
@ -26,7 +26,6 @@ class TestConfig(TlsConfig):
|
||||||
|
|
||||||
|
|
||||||
class TLSConfigTests(TestCase):
|
class TLSConfigTests(TestCase):
|
||||||
|
|
||||||
def test_warn_self_signed(self):
|
def test_warn_self_signed(self):
|
||||||
"""
|
"""
|
||||||
Synapse will give a warning when it loads a self-signed certificate.
|
Synapse will give a warning when it loads a self-signed certificate.
|
||||||
|
@ -34,7 +33,8 @@ class TLSConfigTests(TestCase):
|
||||||
config_dir = self.mktemp()
|
config_dir = self.mktemp()
|
||||||
os.mkdir(config_dir)
|
os.mkdir(config_dir)
|
||||||
with open(os.path.join(config_dir, "cert.pem"), 'w') as f:
|
with open(os.path.join(config_dir, "cert.pem"), 'w') as f:
|
||||||
f.write("""-----BEGIN CERTIFICATE-----
|
f.write(
|
||||||
|
"""-----BEGIN CERTIFICATE-----
|
||||||
MIID6DCCAtACAws9CjANBgkqhkiG9w0BAQUFADCBtzELMAkGA1UEBhMCVFIxDzAN
|
MIID6DCCAtACAws9CjANBgkqhkiG9w0BAQUFADCBtzELMAkGA1UEBhMCVFIxDzAN
|
||||||
BgNVBAgMBsOHb3J1bTEUMBIGA1UEBwwLQmHFn21ha8OnxLExEjAQBgNVBAMMCWxv
|
BgNVBAgMBsOHb3J1bTEUMBIGA1UEBwwLQmHFn21ha8OnxLExEjAQBgNVBAMMCWxv
|
||||||
Y2FsaG9zdDEcMBoGA1UECgwTVHdpc3RlZCBNYXRyaXggTGFiczEkMCIGA1UECwwb
|
Y2FsaG9zdDEcMBoGA1UECgwTVHdpc3RlZCBNYXRyaXggTGFiczEkMCIGA1UECwwb
|
||||||
|
@ -56,11 +56,12 @@ I8OtG1xGwcok53lyDuuUUDexnK4O5BkjKiVlNPg4HPim5Kuj2hRNFfNt/F2BVIlj
|
||||||
iZupikC5MT1LQaRwidkSNxCku1TfAyueiBwhLnFwTmIGNnhuDCutEVAD9kFmcJN2
|
iZupikC5MT1LQaRwidkSNxCku1TfAyueiBwhLnFwTmIGNnhuDCutEVAD9kFmcJN2
|
||||||
SznugAcPk4doX2+rL+ila+ThqgPzIkwTUHtnmjI0TI6xsDUlXz5S3UyudrE2Qsfz
|
SznugAcPk4doX2+rL+ila+ThqgPzIkwTUHtnmjI0TI6xsDUlXz5S3UyudrE2Qsfz
|
||||||
s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
|
s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
|
||||||
-----END CERTIFICATE-----""")
|
-----END CERTIFICATE-----"""
|
||||||
|
)
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"tls_certificate_path": os.path.join(config_dir, "cert.pem"),
|
"tls_certificate_path": os.path.join(config_dir, "cert.pem"),
|
||||||
"tls_fingerprints": []
|
"tls_fingerprints": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
t = TestConfig()
|
t = TestConfig()
|
||||||
|
@ -75,5 +76,5 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg=
|
||||||
"Self-signed TLS certificates will not be accepted by "
|
"Self-signed TLS certificates will not be accepted by "
|
||||||
"Synapse 1.0. Please either provide a valid certificate, "
|
"Synapse 1.0. Please either provide a valid certificate, "
|
||||||
"or use Synapse's ACME support to provision one."
|
"or use Synapse's ACME support to provision one."
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
|
@ -169,7 +169,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
self.http_client.post_json.return_value = defer.Deferred()
|
self.http_client.post_json.return_value = defer.Deferred()
|
||||||
|
|
||||||
res_deferreds_2 = kr.verify_json_objects_for_server(
|
res_deferreds_2 = kr.verify_json_objects_for_server(
|
||||||
[("server10", json1, )]
|
[("server10", json1)]
|
||||||
)
|
)
|
||||||
res_deferreds_2[0].addBoth(self.check_context, None)
|
res_deferreds_2[0].addBoth(self.check_context, None)
|
||||||
yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
|
yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
|
||||||
|
@ -345,6 +345,7 @@ def _verify_json_for_server(keyring, server_name, json_object):
|
||||||
"""thin wrapper around verify_json_for_server which makes sure it is wrapped
|
"""thin wrapper around verify_json_for_server which makes sure it is wrapped
|
||||||
with the patched defer.inlineCallbacks.
|
with the patched defer.inlineCallbacks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def v():
|
def v():
|
||||||
rv1 = yield keyring.verify_json_for_server(server_name, json_object)
|
rv1 = yield keyring.verify_json_for_server(server_name, json_object)
|
||||||
|
|
|
@ -33,11 +33,15 @@ class FederationSenderTestCases(HomeserverTestCase):
|
||||||
mock_state_handler = self.hs.get_state_handler()
|
mock_state_handler = self.hs.get_state_handler()
|
||||||
mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
|
mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
|
||||||
|
|
||||||
mock_send_transaction = self.hs.get_federation_transport_client().send_transaction
|
mock_send_transaction = (
|
||||||
|
self.hs.get_federation_transport_client().send_transaction
|
||||||
|
)
|
||||||
mock_send_transaction.return_value = defer.succeed({})
|
mock_send_transaction.return_value = defer.succeed({})
|
||||||
|
|
||||||
sender = self.hs.get_federation_sender()
|
sender = self.hs.get_federation_sender()
|
||||||
receipt = ReadReceipt("room_id", "m.read", "user_id", ["event_id"], {"ts": 1234})
|
receipt = ReadReceipt(
|
||||||
|
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
|
||||||
|
)
|
||||||
self.successResultOf(sender.send_read_receipt(receipt))
|
self.successResultOf(sender.send_read_receipt(receipt))
|
||||||
|
|
||||||
self.pump()
|
self.pump()
|
||||||
|
@ -46,7 +50,9 @@ class FederationSenderTestCases(HomeserverTestCase):
|
||||||
mock_send_transaction.assert_called_once()
|
mock_send_transaction.assert_called_once()
|
||||||
json_cb = mock_send_transaction.call_args[0][1]
|
json_cb = mock_send_transaction.call_args[0][1]
|
||||||
data = json_cb()
|
data = json_cb()
|
||||||
self.assertEqual(data['edus'], [
|
self.assertEqual(
|
||||||
|
data['edus'],
|
||||||
|
[
|
||||||
{
|
{
|
||||||
'edu_type': 'm.receipt',
|
'edu_type': 'm.receipt',
|
||||||
'content': {
|
'content': {
|
||||||
|
@ -55,12 +61,13 @@ class FederationSenderTestCases(HomeserverTestCase):
|
||||||
'user_id': {
|
'user_id': {
|
||||||
'event_ids': ['event_id'],
|
'event_ids': ['event_id'],
|
||||||
'data': {'ts': 1234},
|
'data': {'ts': 1234},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
},
|
],
|
||||||
},
|
)
|
||||||
},
|
|
||||||
])
|
|
||||||
|
|
||||||
def test_send_receipts_with_backoff(self):
|
def test_send_receipts_with_backoff(self):
|
||||||
"""Send two receipts in quick succession; the second should be flushed, but
|
"""Send two receipts in quick succession; the second should be flushed, but
|
||||||
|
@ -68,11 +75,15 @@ class FederationSenderTestCases(HomeserverTestCase):
|
||||||
mock_state_handler = self.hs.get_state_handler()
|
mock_state_handler = self.hs.get_state_handler()
|
||||||
mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
|
mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
|
||||||
|
|
||||||
mock_send_transaction = self.hs.get_federation_transport_client().send_transaction
|
mock_send_transaction = (
|
||||||
|
self.hs.get_federation_transport_client().send_transaction
|
||||||
|
)
|
||||||
mock_send_transaction.return_value = defer.succeed({})
|
mock_send_transaction.return_value = defer.succeed({})
|
||||||
|
|
||||||
sender = self.hs.get_federation_sender()
|
sender = self.hs.get_federation_sender()
|
||||||
receipt = ReadReceipt("room_id", "m.read", "user_id", ["event_id"], {"ts": 1234})
|
receipt = ReadReceipt(
|
||||||
|
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
|
||||||
|
)
|
||||||
self.successResultOf(sender.send_read_receipt(receipt))
|
self.successResultOf(sender.send_read_receipt(receipt))
|
||||||
|
|
||||||
self.pump()
|
self.pump()
|
||||||
|
@ -81,7 +92,9 @@ class FederationSenderTestCases(HomeserverTestCase):
|
||||||
mock_send_transaction.assert_called_once()
|
mock_send_transaction.assert_called_once()
|
||||||
json_cb = mock_send_transaction.call_args[0][1]
|
json_cb = mock_send_transaction.call_args[0][1]
|
||||||
data = json_cb()
|
data = json_cb()
|
||||||
self.assertEqual(data['edus'], [
|
self.assertEqual(
|
||||||
|
data['edus'],
|
||||||
|
[
|
||||||
{
|
{
|
||||||
'edu_type': 'm.receipt',
|
'edu_type': 'm.receipt',
|
||||||
'content': {
|
'content': {
|
||||||
|
@ -90,16 +103,19 @@ class FederationSenderTestCases(HomeserverTestCase):
|
||||||
'user_id': {
|
'user_id': {
|
||||||
'event_ids': ['event_id'],
|
'event_ids': ['event_id'],
|
||||||
'data': {'ts': 1234},
|
'data': {'ts': 1234},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
},
|
],
|
||||||
},
|
)
|
||||||
},
|
|
||||||
])
|
|
||||||
mock_send_transaction.reset_mock()
|
mock_send_transaction.reset_mock()
|
||||||
|
|
||||||
# send the second RR
|
# send the second RR
|
||||||
receipt = ReadReceipt("room_id", "m.read", "user_id", ["other_id"], {"ts": 1234})
|
receipt = ReadReceipt(
|
||||||
|
"room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
|
||||||
|
)
|
||||||
self.successResultOf(sender.send_read_receipt(receipt))
|
self.successResultOf(sender.send_read_receipt(receipt))
|
||||||
self.pump()
|
self.pump()
|
||||||
mock_send_transaction.assert_not_called()
|
mock_send_transaction.assert_not_called()
|
||||||
|
@ -111,7 +127,9 @@ class FederationSenderTestCases(HomeserverTestCase):
|
||||||
mock_send_transaction.assert_called_once()
|
mock_send_transaction.assert_called_once()
|
||||||
json_cb = mock_send_transaction.call_args[0][1]
|
json_cb = mock_send_transaction.call_args[0][1]
|
||||||
data = json_cb()
|
data = json_cb()
|
||||||
self.assertEqual(data['edus'], [
|
self.assertEqual(
|
||||||
|
data['edus'],
|
||||||
|
[
|
||||||
{
|
{
|
||||||
'edu_type': 'm.receipt',
|
'edu_type': 'm.receipt',
|
||||||
'content': {
|
'content': {
|
||||||
|
@ -120,9 +138,10 @@ class FederationSenderTestCases(HomeserverTestCase):
|
||||||
'user_id': {
|
'user_id': {
|
||||||
'event_ids': ['other_id'],
|
'event_ids': ['other_id'],
|
||||||
'data': {'ts': 1234},
|
'data': {'ts': 1234},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
},
|
],
|
||||||
},
|
)
|
||||||
},
|
|
||||||
])
|
|
||||||
|
|
|
@ -115,11 +115,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
|
||||||
# We cheekily override the config to add custom alias creation rules
|
# We cheekily override the config to add custom alias creation rules
|
||||||
config = {}
|
config = {}
|
||||||
config["alias_creation_rules"] = [
|
config["alias_creation_rules"] = [
|
||||||
{
|
{"user_id": "*", "alias": "#unofficial_*", "action": "allow"}
|
||||||
"user_id": "*",
|
|
||||||
"alias": "#unofficial_*",
|
|
||||||
"action": "allow",
|
|
||||||
}
|
|
||||||
]
|
]
|
||||||
config["room_list_publication_rules"] = []
|
config["room_list_publication_rules"] = []
|
||||||
|
|
||||||
|
@ -162,9 +158,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
||||||
room_id = self.helper.create_room_as(self.user_id)
|
room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"PUT",
|
"PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}'
|
||||||
b"directory/list/room/%s" % (room_id.encode('ascii'),),
|
|
||||||
b'{}',
|
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEquals(200, channel.code, channel.result)
|
self.assertEquals(200, channel.code, channel.result)
|
||||||
|
@ -179,10 +173,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
||||||
self.directory_handler.enable_room_list_search = True
|
self.directory_handler.enable_room_list_search = True
|
||||||
|
|
||||||
# Room list is enabled so we should get some results
|
# Room list is enabled so we should get some results
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request("GET", b"publicRooms")
|
||||||
"GET",
|
|
||||||
b"publicRooms",
|
|
||||||
)
|
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEquals(200, channel.code, channel.result)
|
self.assertEquals(200, channel.code, channel.result)
|
||||||
self.assertTrue(len(channel.json_body["chunk"]) > 0)
|
self.assertTrue(len(channel.json_body["chunk"]) > 0)
|
||||||
|
@ -191,10 +182,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
||||||
self.directory_handler.enable_room_list_search = False
|
self.directory_handler.enable_room_list_search = False
|
||||||
|
|
||||||
# Room list disabled so we should get no results
|
# Room list disabled so we should get no results
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request("GET", b"publicRooms")
|
||||||
"GET",
|
|
||||||
b"publicRooms",
|
|
||||||
)
|
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEquals(200, channel.code, channel.result)
|
self.assertEquals(200, channel.code, channel.result)
|
||||||
self.assertTrue(len(channel.json_body["chunk"]) == 0)
|
self.assertTrue(len(channel.json_body["chunk"]) == 0)
|
||||||
|
@ -202,9 +190,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
|
||||||
# Room list disabled so we shouldn't be allowed to publish rooms
|
# Room list disabled so we shouldn't be allowed to publish rooms
|
||||||
room_id = self.helper.create_room_as(self.user_id)
|
room_id = self.helper.create_room_as(self.user_id)
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"PUT",
|
"PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}'
|
||||||
b"directory/list/room/%s" % (room_id.encode('ascii'),),
|
|
||||||
b'{}',
|
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEquals(403, channel.code, channel.result)
|
self.assertEquals(403, channel.code, channel.result)
|
||||||
|
|
|
@ -36,7 +36,7 @@ room_keys = {
|
||||||
"first_message_index": 1,
|
"first_message_index": 1,
|
||||||
"forwarded_count": 1,
|
"forwarded_count": 1,
|
||||||
"is_verified": False,
|
"is_verified": False,
|
||||||
"session_data": "SSBBTSBBIEZJU0gK"
|
"session_data": "SSBBTSBBIEZJU0gK",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -53,9 +53,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.hs = yield utils.setup_test_homeserver(
|
self.hs = yield utils.setup_test_homeserver(
|
||||||
self.addCleanup,
|
self.addCleanup, handlers=None, replication_layer=mock.Mock()
|
||||||
handlers=None,
|
|
||||||
replication_layer=mock.Mock(),
|
|
||||||
)
|
)
|
||||||
self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
|
self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
|
||||||
self.local_user = "@boris:" + self.hs.hostname
|
self.local_user = "@boris:" + self.hs.hostname
|
||||||
|
@ -88,67 +86,86 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
def test_create_version(self):
|
def test_create_version(self):
|
||||||
"""Check that we can create and then retrieve versions.
|
"""Check that we can create and then retrieve versions.
|
||||||
"""
|
"""
|
||||||
res = yield self.handler.create_version(self.local_user, {
|
res = yield self.handler.create_version(
|
||||||
"algorithm": "m.megolm_backup.v1",
|
self.local_user,
|
||||||
"auth_data": "first_version_auth_data",
|
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||||
})
|
)
|
||||||
self.assertEqual(res, "1")
|
self.assertEqual(res, "1")
|
||||||
|
|
||||||
# check we can retrieve it as the current version
|
# check we can retrieve it as the current version
|
||||||
res = yield self.handler.get_version_info(self.local_user)
|
res = yield self.handler.get_version_info(self.local_user)
|
||||||
self.assertDictEqual(res, {
|
self.assertDictEqual(
|
||||||
|
res,
|
||||||
|
{
|
||||||
"version": "1",
|
"version": "1",
|
||||||
"algorithm": "m.megolm_backup.v1",
|
"algorithm": "m.megolm_backup.v1",
|
||||||
"auth_data": "first_version_auth_data",
|
"auth_data": "first_version_auth_data",
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# check we can retrieve it as a specific version
|
# check we can retrieve it as a specific version
|
||||||
res = yield self.handler.get_version_info(self.local_user, "1")
|
res = yield self.handler.get_version_info(self.local_user, "1")
|
||||||
self.assertDictEqual(res, {
|
self.assertDictEqual(
|
||||||
|
res,
|
||||||
|
{
|
||||||
"version": "1",
|
"version": "1",
|
||||||
"algorithm": "m.megolm_backup.v1",
|
"algorithm": "m.megolm_backup.v1",
|
||||||
"auth_data": "first_version_auth_data",
|
"auth_data": "first_version_auth_data",
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# upload a new one...
|
# upload a new one...
|
||||||
res = yield self.handler.create_version(self.local_user, {
|
res = yield self.handler.create_version(
|
||||||
|
self.local_user,
|
||||||
|
{
|
||||||
"algorithm": "m.megolm_backup.v1",
|
"algorithm": "m.megolm_backup.v1",
|
||||||
"auth_data": "second_version_auth_data",
|
"auth_data": "second_version_auth_data",
|
||||||
})
|
},
|
||||||
|
)
|
||||||
self.assertEqual(res, "2")
|
self.assertEqual(res, "2")
|
||||||
|
|
||||||
# check we can retrieve it as the current version
|
# check we can retrieve it as the current version
|
||||||
res = yield self.handler.get_version_info(self.local_user)
|
res = yield self.handler.get_version_info(self.local_user)
|
||||||
self.assertDictEqual(res, {
|
self.assertDictEqual(
|
||||||
|
res,
|
||||||
|
{
|
||||||
"version": "2",
|
"version": "2",
|
||||||
"algorithm": "m.megolm_backup.v1",
|
"algorithm": "m.megolm_backup.v1",
|
||||||
"auth_data": "second_version_auth_data",
|
"auth_data": "second_version_auth_data",
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_update_version(self):
|
def test_update_version(self):
|
||||||
"""Check that we can update versions.
|
"""Check that we can update versions.
|
||||||
"""
|
"""
|
||||||
version = yield self.handler.create_version(self.local_user, {
|
version = yield self.handler.create_version(
|
||||||
"algorithm": "m.megolm_backup.v1",
|
self.local_user,
|
||||||
"auth_data": "first_version_auth_data",
|
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||||
})
|
)
|
||||||
self.assertEqual(version, "1")
|
self.assertEqual(version, "1")
|
||||||
|
|
||||||
res = yield self.handler.update_version(self.local_user, version, {
|
res = yield self.handler.update_version(
|
||||||
|
self.local_user,
|
||||||
|
version,
|
||||||
|
{
|
||||||
"algorithm": "m.megolm_backup.v1",
|
"algorithm": "m.megolm_backup.v1",
|
||||||
"auth_data": "revised_first_version_auth_data",
|
"auth_data": "revised_first_version_auth_data",
|
||||||
"version": version
|
"version": version,
|
||||||
})
|
},
|
||||||
|
)
|
||||||
self.assertDictEqual(res, {})
|
self.assertDictEqual(res, {})
|
||||||
|
|
||||||
# check we can retrieve it as the current version
|
# check we can retrieve it as the current version
|
||||||
res = yield self.handler.get_version_info(self.local_user)
|
res = yield self.handler.get_version_info(self.local_user)
|
||||||
self.assertDictEqual(res, {
|
self.assertDictEqual(
|
||||||
|
res,
|
||||||
|
{
|
||||||
"algorithm": "m.megolm_backup.v1",
|
"algorithm": "m.megolm_backup.v1",
|
||||||
"auth_data": "revised_first_version_auth_data",
|
"auth_data": "revised_first_version_auth_data",
|
||||||
"version": version
|
"version": version,
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_update_missing_version(self):
|
def test_update_missing_version(self):
|
||||||
|
@ -156,11 +173,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
res = None
|
res = None
|
||||||
try:
|
try:
|
||||||
yield self.handler.update_version(self.local_user, "1", {
|
yield self.handler.update_version(
|
||||||
|
self.local_user,
|
||||||
|
"1",
|
||||||
|
{
|
||||||
"algorithm": "m.megolm_backup.v1",
|
"algorithm": "m.megolm_backup.v1",
|
||||||
"auth_data": "revised_first_version_auth_data",
|
"auth_data": "revised_first_version_auth_data",
|
||||||
"version": "1"
|
"version": "1",
|
||||||
})
|
},
|
||||||
|
)
|
||||||
except errors.SynapseError as e:
|
except errors.SynapseError as e:
|
||||||
res = e.code
|
res = e.code
|
||||||
self.assertEqual(res, 404)
|
self.assertEqual(res, 404)
|
||||||
|
@ -170,29 +191,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
"""Check that we get a 400 if the version in the body is missing or
|
"""Check that we get a 400 if the version in the body is missing or
|
||||||
doesn't match
|
doesn't match
|
||||||
"""
|
"""
|
||||||
version = yield self.handler.create_version(self.local_user, {
|
version = yield self.handler.create_version(
|
||||||
"algorithm": "m.megolm_backup.v1",
|
self.local_user,
|
||||||
"auth_data": "first_version_auth_data",
|
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||||
})
|
)
|
||||||
self.assertEqual(version, "1")
|
self.assertEqual(version, "1")
|
||||||
|
|
||||||
res = None
|
res = None
|
||||||
try:
|
try:
|
||||||
yield self.handler.update_version(self.local_user, version, {
|
yield self.handler.update_version(
|
||||||
|
self.local_user,
|
||||||
|
version,
|
||||||
|
{
|
||||||
"algorithm": "m.megolm_backup.v1",
|
"algorithm": "m.megolm_backup.v1",
|
||||||
"auth_data": "revised_first_version_auth_data"
|
"auth_data": "revised_first_version_auth_data",
|
||||||
})
|
},
|
||||||
|
)
|
||||||
except errors.SynapseError as e:
|
except errors.SynapseError as e:
|
||||||
res = e.code
|
res = e.code
|
||||||
self.assertEqual(res, 400)
|
self.assertEqual(res, 400)
|
||||||
|
|
||||||
res = None
|
res = None
|
||||||
try:
|
try:
|
||||||
yield self.handler.update_version(self.local_user, version, {
|
yield self.handler.update_version(
|
||||||
|
self.local_user,
|
||||||
|
version,
|
||||||
|
{
|
||||||
"algorithm": "m.megolm_backup.v1",
|
"algorithm": "m.megolm_backup.v1",
|
||||||
"auth_data": "revised_first_version_auth_data",
|
"auth_data": "revised_first_version_auth_data",
|
||||||
"version": "incorrect"
|
"version": "incorrect",
|
||||||
})
|
},
|
||||||
|
)
|
||||||
except errors.SynapseError as e:
|
except errors.SynapseError as e:
|
||||||
res = e.code
|
res = e.code
|
||||||
self.assertEqual(res, 400)
|
self.assertEqual(res, 400)
|
||||||
|
@ -223,10 +252,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
def test_delete_version(self):
|
def test_delete_version(self):
|
||||||
"""Check that we can create and then delete versions.
|
"""Check that we can create and then delete versions.
|
||||||
"""
|
"""
|
||||||
res = yield self.handler.create_version(self.local_user, {
|
res = yield self.handler.create_version(
|
||||||
"algorithm": "m.megolm_backup.v1",
|
self.local_user,
|
||||||
"auth_data": "first_version_auth_data",
|
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||||
})
|
)
|
||||||
self.assertEqual(res, "1")
|
self.assertEqual(res, "1")
|
||||||
|
|
||||||
# check we can delete it
|
# check we can delete it
|
||||||
|
@ -255,16 +284,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
def test_get_missing_room_keys(self):
|
def test_get_missing_room_keys(self):
|
||||||
"""Check we get an empty response from an empty backup
|
"""Check we get an empty response from an empty backup
|
||||||
"""
|
"""
|
||||||
version = yield self.handler.create_version(self.local_user, {
|
version = yield self.handler.create_version(
|
||||||
"algorithm": "m.megolm_backup.v1",
|
self.local_user,
|
||||||
"auth_data": "first_version_auth_data",
|
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||||
})
|
)
|
||||||
self.assertEqual(version, "1")
|
self.assertEqual(version, "1")
|
||||||
|
|
||||||
res = yield self.handler.get_room_keys(self.local_user, version)
|
res = yield self.handler.get_room_keys(self.local_user, version)
|
||||||
self.assertDictEqual(res, {
|
self.assertDictEqual(res, {"rooms": {}})
|
||||||
"rooms": {}
|
|
||||||
})
|
|
||||||
|
|
||||||
# TODO: test the locking semantics when uploading room_keys,
|
# TODO: test the locking semantics when uploading room_keys,
|
||||||
# although this is probably best done in sytest
|
# although this is probably best done in sytest
|
||||||
|
@ -275,7 +302,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
res = None
|
res = None
|
||||||
try:
|
try:
|
||||||
yield self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
|
yield self.handler.upload_room_keys(
|
||||||
|
self.local_user, "no_version", room_keys
|
||||||
|
)
|
||||||
except errors.SynapseError as e:
|
except errors.SynapseError as e:
|
||||||
res = e.code
|
res = e.code
|
||||||
self.assertEqual(res, 404)
|
self.assertEqual(res, 404)
|
||||||
|
@ -285,10 +314,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
"""Check that we get a 404 on uploading keys when an nonexistent version
|
"""Check that we get a 404 on uploading keys when an nonexistent version
|
||||||
is specified
|
is specified
|
||||||
"""
|
"""
|
||||||
version = yield self.handler.create_version(self.local_user, {
|
version = yield self.handler.create_version(
|
||||||
"algorithm": "m.megolm_backup.v1",
|
self.local_user,
|
||||||
"auth_data": "first_version_auth_data",
|
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||||
})
|
)
|
||||||
self.assertEqual(version, "1")
|
self.assertEqual(version, "1")
|
||||||
|
|
||||||
res = None
|
res = None
|
||||||
|
@ -304,16 +333,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
def test_upload_room_keys_wrong_version(self):
|
def test_upload_room_keys_wrong_version(self):
|
||||||
"""Check that we get a 403 on uploading keys for an old version
|
"""Check that we get a 403 on uploading keys for an old version
|
||||||
"""
|
"""
|
||||||
version = yield self.handler.create_version(self.local_user, {
|
version = yield self.handler.create_version(
|
||||||
"algorithm": "m.megolm_backup.v1",
|
self.local_user,
|
||||||
"auth_data": "first_version_auth_data",
|
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||||
})
|
)
|
||||||
self.assertEqual(version, "1")
|
self.assertEqual(version, "1")
|
||||||
|
|
||||||
version = yield self.handler.create_version(self.local_user, {
|
version = yield self.handler.create_version(
|
||||||
|
self.local_user,
|
||||||
|
{
|
||||||
"algorithm": "m.megolm_backup.v1",
|
"algorithm": "m.megolm_backup.v1",
|
||||||
"auth_data": "second_version_auth_data",
|
"auth_data": "second_version_auth_data",
|
||||||
})
|
},
|
||||||
|
)
|
||||||
self.assertEqual(version, "2")
|
self.assertEqual(version, "2")
|
||||||
|
|
||||||
res = None
|
res = None
|
||||||
|
@ -327,10 +359,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
def test_upload_room_keys_insert(self):
|
def test_upload_room_keys_insert(self):
|
||||||
"""Check that we can insert and retrieve keys for a session
|
"""Check that we can insert and retrieve keys for a session
|
||||||
"""
|
"""
|
||||||
version = yield self.handler.create_version(self.local_user, {
|
version = yield self.handler.create_version(
|
||||||
"algorithm": "m.megolm_backup.v1",
|
self.local_user,
|
||||||
"auth_data": "first_version_auth_data",
|
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||||
})
|
)
|
||||||
self.assertEqual(version, "1")
|
self.assertEqual(version, "1")
|
||||||
|
|
||||||
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
|
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
|
||||||
|
@ -340,18 +372,13 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# check getting room_keys for a given room
|
# check getting room_keys for a given room
|
||||||
res = yield self.handler.get_room_keys(
|
res = yield self.handler.get_room_keys(
|
||||||
self.local_user,
|
self.local_user, version, room_id="!abc:matrix.org"
|
||||||
version,
|
|
||||||
room_id="!abc:matrix.org"
|
|
||||||
)
|
)
|
||||||
self.assertDictEqual(res, room_keys)
|
self.assertDictEqual(res, room_keys)
|
||||||
|
|
||||||
# check getting room_keys for a given session_id
|
# check getting room_keys for a given session_id
|
||||||
res = yield self.handler.get_room_keys(
|
res = yield self.handler.get_room_keys(
|
||||||
self.local_user,
|
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
|
||||||
version,
|
|
||||||
room_id="!abc:matrix.org",
|
|
||||||
session_id="c0ff33",
|
|
||||||
)
|
)
|
||||||
self.assertDictEqual(res, room_keys)
|
self.assertDictEqual(res, room_keys)
|
||||||
|
|
||||||
|
@ -359,10 +386,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
def test_upload_room_keys_merge(self):
|
def test_upload_room_keys_merge(self):
|
||||||
"""Check that we can upload a new room_key for an existing session and
|
"""Check that we can upload a new room_key for an existing session and
|
||||||
have it correctly merged"""
|
have it correctly merged"""
|
||||||
version = yield self.handler.create_version(self.local_user, {
|
version = yield self.handler.create_version(
|
||||||
"algorithm": "m.megolm_backup.v1",
|
self.local_user,
|
||||||
"auth_data": "first_version_auth_data",
|
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||||
})
|
)
|
||||||
self.assertEqual(version, "1")
|
self.assertEqual(version, "1")
|
||||||
|
|
||||||
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
|
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
|
||||||
|
@ -378,7 +405,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
res = yield self.handler.get_room_keys(self.local_user, version)
|
res = yield self.handler.get_room_keys(self.local_user, version)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'],
|
res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'],
|
||||||
"SSBBTSBBIEZJU0gK"
|
"SSBBTSBBIEZJU0gK",
|
||||||
)
|
)
|
||||||
|
|
||||||
# test that marking the session as verified however /does/ replace it
|
# test that marking the session as verified however /does/ replace it
|
||||||
|
@ -387,8 +414,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
|
|
||||||
res = yield self.handler.get_room_keys(self.local_user, version)
|
res = yield self.handler.get_room_keys(self.local_user, version)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'],
|
res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new"
|
||||||
"new"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# test that a session with a higher forwarded_count doesn't replace one
|
# test that a session with a higher forwarded_count doesn't replace one
|
||||||
|
@ -399,8 +425,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
|
|
||||||
res = yield self.handler.get_room_keys(self.local_user, version)
|
res = yield self.handler.get_room_keys(self.local_user, version)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'],
|
res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new"
|
||||||
"new"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: check edge cases as well as the common variations here
|
# TODO: check edge cases as well as the common variations here
|
||||||
|
@ -409,56 +434,36 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||||
def test_delete_room_keys(self):
|
def test_delete_room_keys(self):
|
||||||
"""Check that we can insert and delete keys for a session
|
"""Check that we can insert and delete keys for a session
|
||||||
"""
|
"""
|
||||||
version = yield self.handler.create_version(self.local_user, {
|
version = yield self.handler.create_version(
|
||||||
"algorithm": "m.megolm_backup.v1",
|
self.local_user,
|
||||||
"auth_data": "first_version_auth_data",
|
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||||
})
|
)
|
||||||
self.assertEqual(version, "1")
|
self.assertEqual(version, "1")
|
||||||
|
|
||||||
# check for bulk-delete
|
# check for bulk-delete
|
||||||
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
|
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
|
||||||
yield self.handler.delete_room_keys(self.local_user, version)
|
yield self.handler.delete_room_keys(self.local_user, version)
|
||||||
res = yield self.handler.get_room_keys(
|
res = yield self.handler.get_room_keys(
|
||||||
self.local_user,
|
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
|
||||||
version,
|
|
||||||
room_id="!abc:matrix.org",
|
|
||||||
session_id="c0ff33",
|
|
||||||
)
|
)
|
||||||
self.assertDictEqual(res, {
|
self.assertDictEqual(res, {"rooms": {}})
|
||||||
"rooms": {}
|
|
||||||
})
|
|
||||||
|
|
||||||
# check for bulk-delete per room
|
# check for bulk-delete per room
|
||||||
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
|
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
|
||||||
yield self.handler.delete_room_keys(
|
yield self.handler.delete_room_keys(
|
||||||
self.local_user,
|
self.local_user, version, room_id="!abc:matrix.org"
|
||||||
version,
|
|
||||||
room_id="!abc:matrix.org",
|
|
||||||
)
|
)
|
||||||
res = yield self.handler.get_room_keys(
|
res = yield self.handler.get_room_keys(
|
||||||
self.local_user,
|
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
|
||||||
version,
|
|
||||||
room_id="!abc:matrix.org",
|
|
||||||
session_id="c0ff33",
|
|
||||||
)
|
)
|
||||||
self.assertDictEqual(res, {
|
self.assertDictEqual(res, {"rooms": {}})
|
||||||
"rooms": {}
|
|
||||||
})
|
|
||||||
|
|
||||||
# check for bulk-delete per session
|
# check for bulk-delete per session
|
||||||
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
|
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
|
||||||
yield self.handler.delete_room_keys(
|
yield self.handler.delete_room_keys(
|
||||||
self.local_user,
|
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
|
||||||
version,
|
|
||||||
room_id="!abc:matrix.org",
|
|
||||||
session_id="c0ff33",
|
|
||||||
)
|
)
|
||||||
res = yield self.handler.get_room_keys(
|
res = yield self.handler.get_room_keys(
|
||||||
self.local_user,
|
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
|
||||||
version,
|
|
||||||
room_id="!abc:matrix.org",
|
|
||||||
session_id="c0ff33",
|
|
||||||
)
|
)
|
||||||
self.assertDictEqual(res, {
|
self.assertDictEqual(res, {"rooms": {}})
|
||||||
"rooms": {}
|
|
||||||
})
|
|
||||||
|
|
|
@ -424,8 +424,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
"server", http_client=None,
|
"server", http_client=None, federation_sender=Mock()
|
||||||
federation_sender=Mock(),
|
|
||||||
)
|
)
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
|
@ -457,7 +456,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Mark test2 as online, test will be offline with a last_active of 0
|
# Mark test2 as online, test will be offline with a last_active of 0
|
||||||
self.presence_handler.set_state(
|
self.presence_handler.set_state(
|
||||||
UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE},
|
UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
|
||||||
)
|
)
|
||||||
self.reactor.pump([0]) # Wait for presence updates to be handled
|
self.reactor.pump([0]) # Wait for presence updates to be handled
|
||||||
|
|
||||||
|
@ -506,13 +505,13 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Mark test as online
|
# Mark test as online
|
||||||
self.presence_handler.set_state(
|
self.presence_handler.set_state(
|
||||||
UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE},
|
UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mark test2 as online, test will be offline with a last_active of 0.
|
# Mark test2 as online, test will be offline with a last_active of 0.
|
||||||
# Note we don't join them to the room yet
|
# Note we don't join them to the room yet
|
||||||
self.presence_handler.set_state(
|
self.presence_handler.set_state(
|
||||||
UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE},
|
UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add servers to the room
|
# Add servers to the room
|
||||||
|
@ -541,8 +540,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(expected_state.state, PresenceState.ONLINE)
|
self.assertEqual(expected_state.state, PresenceState.ONLINE)
|
||||||
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
|
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
|
||||||
destinations=set(("server2", "server3")),
|
destinations=set(("server2", "server3")), states=[expected_state]
|
||||||
states=[expected_state]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_new_user(self, room_id, user_id):
|
def _add_new_user(self, room_id, user_id):
|
||||||
|
@ -565,7 +563,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
||||||
type=EventTypes.Member,
|
type=EventTypes.Member,
|
||||||
sender=user_id,
|
sender=user_id,
|
||||||
state_key=user_id,
|
state_key=user_id,
|
||||||
content={"membership": Membership.JOIN}
|
content={"membership": Membership.JOIN},
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_event_ids = self.get_success(
|
prev_event_ids = self.get_success(
|
||||||
|
|
|
@ -64,7 +64,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
|
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
datastore=(Mock(
|
datastore=(
|
||||||
|
Mock(
|
||||||
spec=[
|
spec=[
|
||||||
# Bits that Federation needs
|
# Bits that Federation needs
|
||||||
"prep_send_transaction",
|
"prep_send_transaction",
|
||||||
|
@ -77,7 +78,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
"get_user_directory_stream_pos",
|
"get_user_directory_stream_pos",
|
||||||
"get_current_state_deltas",
|
"get_current_state_deltas",
|
||||||
]
|
]
|
||||||
)),
|
)
|
||||||
|
),
|
||||||
notifier=Mock(),
|
notifier=Mock(),
|
||||||
http_client=mock_federation_client,
|
http_client=mock_federation_client,
|
||||||
keyring=mock_keyring,
|
keyring=mock_keyring,
|
||||||
|
@ -114,6 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
def check_joined_room(room_id, user_id):
|
def check_joined_room(room_id, user_id):
|
||||||
if user_id not in [u.to_string() for u in self.room_members]:
|
if user_id not in [u.to_string() for u in self.room_members]:
|
||||||
raise AuthError(401, "User is not in the room")
|
raise AuthError(401, "User is not in the room")
|
||||||
|
|
||||||
hs.get_auth().check_joined_room = check_joined_room
|
hs.get_auth().check_joined_room = check_joined_room
|
||||||
|
|
||||||
def get_joined_hosts_for_room(room_id):
|
def get_joined_hosts_for_room(room_id):
|
||||||
|
@ -123,6 +126,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def get_current_users_in_room(room_id):
|
def get_current_users_in_room(room_id):
|
||||||
return set(str(u) for u in self.room_members)
|
return set(str(u) for u in self.room_members)
|
||||||
|
|
||||||
hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
|
hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
|
||||||
|
|
||||||
self.datastore.get_user_directory_stream_pos.return_value = (
|
self.datastore.get_user_directory_stream_pos.return_value = (
|
||||||
|
@ -141,21 +145,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 0)
|
self.assertEquals(self.event_source.get_current_key(), 0)
|
||||||
|
|
||||||
self.successResultOf(self.handler.started_typing(
|
self.successResultOf(
|
||||||
target_user=U_APPLE,
|
self.handler.started_typing(
|
||||||
auth_user=U_APPLE,
|
target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
|
||||||
room_id=ROOM_ID,
|
|
||||||
timeout=20000,
|
|
||||||
))
|
|
||||||
|
|
||||||
self.on_new_event.assert_has_calls(
|
|
||||||
[call('typing_key', 1, rooms=[ROOM_ID])]
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
events = self.event_source.get_new_events(
|
events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
|
||||||
room_ids=[ROOM_ID], from_key=0
|
|
||||||
)
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
[
|
[
|
||||||
|
@ -170,12 +169,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
def test_started_typing_remote_send(self):
|
def test_started_typing_remote_send(self):
|
||||||
self.room_members = [U_APPLE, U_ONION]
|
self.room_members = [U_APPLE, U_ONION]
|
||||||
|
|
||||||
self.successResultOf(self.handler.started_typing(
|
self.successResultOf(
|
||||||
target_user=U_APPLE,
|
self.handler.started_typing(
|
||||||
auth_user=U_APPLE,
|
target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
|
||||||
room_id=ROOM_ID,
|
)
|
||||||
timeout=20000,
|
)
|
||||||
))
|
|
||||||
|
|
||||||
put_json = self.hs.get_http_client().put_json
|
put_json = self.hs.get_http_client().put_json
|
||||||
put_json.assert_called_once_with(
|
put_json.assert_called_once_with(
|
||||||
|
@ -216,14 +214,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
|
|
||||||
self.on_new_event.assert_has_calls(
|
self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
|
||||||
[call('typing_key', 1, rooms=[ROOM_ID])]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
events = self.event_source.get_new_events(
|
events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
|
||||||
room_ids=[ROOM_ID], from_key=0
|
|
||||||
)
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
[
|
[
|
||||||
|
@ -247,13 +241,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 0)
|
self.assertEquals(self.event_source.get_current_key(), 0)
|
||||||
|
|
||||||
self.successResultOf(self.handler.stopped_typing(
|
self.successResultOf(
|
||||||
|
self.handler.stopped_typing(
|
||||||
target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
|
target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
|
||||||
))
|
|
||||||
|
|
||||||
self.on_new_event.assert_has_calls(
|
|
||||||
[call('typing_key', 1, rooms=[ROOM_ID])]
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
|
||||||
|
|
||||||
put_json = self.hs.get_http_client().put_json
|
put_json = self.hs.get_http_client().put_json
|
||||||
put_json.assert_called_once_with(
|
put_json.assert_called_once_with(
|
||||||
|
@ -274,18 +268,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
events = self.event_source.get_new_events(
|
events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
|
||||||
room_ids=[ROOM_ID], from_key=0
|
|
||||||
)
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
[
|
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
|
||||||
{
|
|
||||||
"type": "m.typing",
|
|
||||||
"room_id": ROOM_ID,
|
|
||||||
"content": {"user_ids": []},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_typing_timeout(self):
|
def test_typing_timeout(self):
|
||||||
|
@ -293,22 +279,17 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 0)
|
self.assertEquals(self.event_source.get_current_key(), 0)
|
||||||
|
|
||||||
self.successResultOf(self.handler.started_typing(
|
self.successResultOf(
|
||||||
target_user=U_APPLE,
|
self.handler.started_typing(
|
||||||
auth_user=U_APPLE,
|
target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
|
||||||
room_id=ROOM_ID,
|
|
||||||
timeout=10000,
|
|
||||||
))
|
|
||||||
|
|
||||||
self.on_new_event.assert_has_calls(
|
|
||||||
[call('typing_key', 1, rooms=[ROOM_ID])]
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
|
||||||
self.on_new_event.reset_mock()
|
self.on_new_event.reset_mock()
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
events = self.event_source.get_new_events(
|
events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
|
||||||
room_ids=[ROOM_ID], from_key=0
|
|
||||||
)
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
[
|
[
|
||||||
|
@ -320,45 +301,30 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.reactor.pump([16, ])
|
self.reactor.pump([16])
|
||||||
|
|
||||||
self.on_new_event.assert_has_calls(
|
self.on_new_event.assert_has_calls([call('typing_key', 2, rooms=[ROOM_ID])])
|
||||||
[call('typing_key', 2, rooms=[ROOM_ID])]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 2)
|
self.assertEquals(self.event_source.get_current_key(), 2)
|
||||||
events = self.event_source.get_new_events(
|
events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
|
||||||
room_ids=[ROOM_ID], from_key=1
|
|
||||||
)
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
[
|
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
|
||||||
{
|
|
||||||
"type": "m.typing",
|
|
||||||
"room_id": ROOM_ID,
|
|
||||||
"content": {"user_ids": []},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# SYN-230 - see if we can still set after timeout
|
# SYN-230 - see if we can still set after timeout
|
||||||
|
|
||||||
self.successResultOf(self.handler.started_typing(
|
self.successResultOf(
|
||||||
target_user=U_APPLE,
|
self.handler.started_typing(
|
||||||
auth_user=U_APPLE,
|
target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
|
||||||
room_id=ROOM_ID,
|
|
||||||
timeout=10000,
|
|
||||||
))
|
|
||||||
|
|
||||||
self.on_new_event.assert_has_calls(
|
|
||||||
[call('typing_key', 3, rooms=[ROOM_ID])]
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.on_new_event.assert_has_calls([call('typing_key', 3, rooms=[ROOM_ID])])
|
||||||
self.on_new_event.reset_mock()
|
self.on_new_event.reset_mock()
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 3)
|
self.assertEquals(self.event_source.get_current_key(), 3)
|
||||||
events = self.event_source.get_new_events(
|
events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
|
||||||
room_ids=[ROOM_ID], from_key=0
|
|
||||||
)
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
[
|
[
|
||||||
|
|
|
@ -352,9 +352,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Assert user directory is not empty
|
# Assert user directory is not empty
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"POST",
|
"POST", b"user_directory/search", b'{"search_term":"user2"}'
|
||||||
b"user_directory/search",
|
|
||||||
b'{"search_term":"user2"}',
|
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEquals(200, channel.code, channel.result)
|
self.assertEquals(200, channel.code, channel.result)
|
||||||
|
@ -363,9 +361,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
|
||||||
# Disable user directory and check search returns nothing
|
# Disable user directory and check search returns nothing
|
||||||
self.config.user_directory_search_enabled = False
|
self.config.user_directory_search_enabled = False
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"POST",
|
"POST", b"user_directory/search", b'{"search_term":"user2"}'
|
||||||
b"user_directory/search",
|
|
||||||
b'{"search_term":"user2"}',
|
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEquals(200, channel.code, channel.result)
|
self.assertEquals(200, channel.code, channel.result)
|
||||||
|
|
|
@ -24,14 +24,12 @@ def get_test_cert_file():
|
||||||
#
|
#
|
||||||
# openssl req -x509 -newkey rsa:4096 -keyout server.pem -out server.pem -days 36500 \
|
# openssl req -x509 -newkey rsa:4096 -keyout server.pem -out server.pem -days 36500 \
|
||||||
# -nodes -subj '/CN=testserv'
|
# -nodes -subj '/CN=testserv'
|
||||||
return os.path.join(
|
return os.path.join(os.path.dirname(__file__), 'server.pem')
|
||||||
os.path.dirname(__file__),
|
|
||||||
'server.pem',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ServerTLSContext(object):
|
class ServerTLSContext(object):
|
||||||
"""A TLS Context which presents our test cert."""
|
"""A TLS Context which presents our test cert."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.filename = get_test_cert_file()
|
self.filename = get_test_cert_file()
|
||||||
|
|
||||||
|
|
|
@ -79,12 +79,12 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
# stubbing that out here.
|
# stubbing that out here.
|
||||||
client_protocol = client_factory.buildProtocol(None)
|
client_protocol = client_factory.buildProtocol(None)
|
||||||
client_protocol.makeConnection(
|
client_protocol.makeConnection(
|
||||||
FakeTransport(server_tls_protocol, self.reactor, client_protocol),
|
FakeTransport(server_tls_protocol, self.reactor, client_protocol)
|
||||||
)
|
)
|
||||||
|
|
||||||
# tell the server tls protocol to send its stuff back to the client, too
|
# tell the server tls protocol to send its stuff back to the client, too
|
||||||
server_tls_protocol.makeConnection(
|
server_tls_protocol.makeConnection(
|
||||||
FakeTransport(client_protocol, self.reactor, server_tls_protocol),
|
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
|
||||||
)
|
)
|
||||||
|
|
||||||
# give the reactor a pump to get the TLS juices flowing.
|
# give the reactor a pump to get the TLS juices flowing.
|
||||||
|
@ -125,7 +125,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
_check_logcontext(context)
|
_check_logcontext(context)
|
||||||
|
|
||||||
def _handle_well_known_connection(
|
def _handle_well_known_connection(
|
||||||
self, client_factory, expected_sni, content, response_headers={},
|
self, client_factory, expected_sni, content, response_headers={}
|
||||||
):
|
):
|
||||||
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the
|
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the
|
||||||
request is for a .well-known, and send the response.
|
request is for a .well-known, and send the response.
|
||||||
|
@ -139,8 +139,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
"""
|
"""
|
||||||
# make the connection for .well-known
|
# make the connection for .well-known
|
||||||
well_known_server = self._make_connection(
|
well_known_server = self._make_connection(
|
||||||
client_factory,
|
client_factory, expected_sni=expected_sni
|
||||||
expected_sni=expected_sni,
|
|
||||||
)
|
)
|
||||||
# check the .well-known request and send a response
|
# check the .well-known request and send a response
|
||||||
self.assertEqual(len(well_known_server.requests), 1)
|
self.assertEqual(len(well_known_server.requests), 1)
|
||||||
|
@ -154,10 +153,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
"""
|
"""
|
||||||
self.assertEqual(request.method, b'GET')
|
self.assertEqual(request.method, b'GET')
|
||||||
self.assertEqual(request.path, b'/.well-known/matrix/server')
|
self.assertEqual(request.path, b'/.well-known/matrix/server')
|
||||||
self.assertEqual(
|
self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])
|
||||||
request.requestHeaders.getRawHeaders(b'host'),
|
|
||||||
[b'testserv'],
|
|
||||||
)
|
|
||||||
# send back a response
|
# send back a response
|
||||||
for k, v in headers.items():
|
for k, v in headers.items():
|
||||||
request.setHeader(k, v)
|
request.setHeader(k, v)
|
||||||
|
@ -184,18 +180,14 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(port, 8448)
|
self.assertEqual(port, 8448)
|
||||||
|
|
||||||
# make a test server, and wire up the client
|
# make a test server, and wire up the client
|
||||||
http_server = self._make_connection(
|
http_server = self._make_connection(client_factory, expected_sni=b"testserv")
|
||||||
client_factory,
|
|
||||||
expected_sni=b"testserv",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(len(http_server.requests), 1)
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
request = http_server.requests[0]
|
request = http_server.requests[0]
|
||||||
self.assertEqual(request.method, b'GET')
|
self.assertEqual(request.method, b'GET')
|
||||||
self.assertEqual(request.path, b'/foo/bar')
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
request.requestHeaders.getRawHeaders(b'host'),
|
request.requestHeaders.getRawHeaders(b'host'), [b'testserv:8448']
|
||||||
[b'testserv:8448']
|
|
||||||
)
|
)
|
||||||
content = request.content.read()
|
content = request.content.read()
|
||||||
self.assertEqual(content, b'')
|
self.assertEqual(content, b'')
|
||||||
|
@ -244,19 +236,13 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(port, 8448)
|
self.assertEqual(port, 8448)
|
||||||
|
|
||||||
# make a test server, and wire up the client
|
# make a test server, and wire up the client
|
||||||
http_server = self._make_connection(
|
http_server = self._make_connection(client_factory, expected_sni=None)
|
||||||
client_factory,
|
|
||||||
expected_sni=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(len(http_server.requests), 1)
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
request = http_server.requests[0]
|
request = http_server.requests[0]
|
||||||
self.assertEqual(request.method, b'GET')
|
self.assertEqual(request.method, b'GET')
|
||||||
self.assertEqual(request.path, b'/foo/bar')
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
self.assertEqual(
|
self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'1.2.3.4'])
|
||||||
request.requestHeaders.getRawHeaders(b'host'),
|
|
||||||
[b'1.2.3.4'],
|
|
||||||
)
|
|
||||||
|
|
||||||
# finish the request
|
# finish the request
|
||||||
request.finish()
|
request.finish()
|
||||||
|
@ -285,19 +271,13 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(port, 8448)
|
self.assertEqual(port, 8448)
|
||||||
|
|
||||||
# make a test server, and wire up the client
|
# make a test server, and wire up the client
|
||||||
http_server = self._make_connection(
|
http_server = self._make_connection(client_factory, expected_sni=None)
|
||||||
client_factory,
|
|
||||||
expected_sni=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(len(http_server.requests), 1)
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
request = http_server.requests[0]
|
request = http_server.requests[0]
|
||||||
self.assertEqual(request.method, b'GET')
|
self.assertEqual(request.method, b'GET')
|
||||||
self.assertEqual(request.path, b'/foo/bar')
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
self.assertEqual(
|
self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]'])
|
||||||
request.requestHeaders.getRawHeaders(b'host'),
|
|
||||||
[b'[::1]'],
|
|
||||||
)
|
|
||||||
|
|
||||||
# finish the request
|
# finish the request
|
||||||
request.finish()
|
request.finish()
|
||||||
|
@ -326,19 +306,13 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(port, 80)
|
self.assertEqual(port, 80)
|
||||||
|
|
||||||
# make a test server, and wire up the client
|
# make a test server, and wire up the client
|
||||||
http_server = self._make_connection(
|
http_server = self._make_connection(client_factory, expected_sni=None)
|
||||||
client_factory,
|
|
||||||
expected_sni=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(len(http_server.requests), 1)
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
request = http_server.requests[0]
|
request = http_server.requests[0]
|
||||||
self.assertEqual(request.method, b'GET')
|
self.assertEqual(request.method, b'GET')
|
||||||
self.assertEqual(request.path, b'/foo/bar')
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
self.assertEqual(
|
self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]:80'])
|
||||||
request.requestHeaders.getRawHeaders(b'host'),
|
|
||||||
[b'[::1]:80'],
|
|
||||||
)
|
|
||||||
|
|
||||||
# finish the request
|
# finish the request
|
||||||
request.finish()
|
request.finish()
|
||||||
|
@ -377,7 +351,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
|
|
||||||
# now there should be a SRV lookup
|
# now there should be a SRV lookup
|
||||||
self.mock_resolver.resolve_service.assert_called_once_with(
|
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||||
b"_matrix._tcp.testserv",
|
b"_matrix._tcp.testserv"
|
||||||
)
|
)
|
||||||
|
|
||||||
# we should fall back to a direct connection
|
# we should fall back to a direct connection
|
||||||
|
@ -387,19 +361,13 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(port, 8448)
|
self.assertEqual(port, 8448)
|
||||||
|
|
||||||
# make a test server, and wire up the client
|
# make a test server, and wire up the client
|
||||||
http_server = self._make_connection(
|
http_server = self._make_connection(client_factory, expected_sni=b'testserv')
|
||||||
client_factory,
|
|
||||||
expected_sni=b'testserv',
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(len(http_server.requests), 1)
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
request = http_server.requests[0]
|
request = http_server.requests[0]
|
||||||
self.assertEqual(request.method, b'GET')
|
self.assertEqual(request.method, b'GET')
|
||||||
self.assertEqual(request.path, b'/foo/bar')
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
self.assertEqual(
|
self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])
|
||||||
request.requestHeaders.getRawHeaders(b'host'),
|
|
||||||
[b'testserv'],
|
|
||||||
)
|
|
||||||
|
|
||||||
# finish the request
|
# finish the request
|
||||||
request.finish()
|
request.finish()
|
||||||
|
@ -427,13 +395,14 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(port, 443)
|
self.assertEqual(port, 443)
|
||||||
|
|
||||||
self._handle_well_known_connection(
|
self._handle_well_known_connection(
|
||||||
client_factory, expected_sni=b"testserv",
|
client_factory,
|
||||||
|
expected_sni=b"testserv",
|
||||||
content=b'{ "m.server": "target-server" }',
|
content=b'{ "m.server": "target-server" }',
|
||||||
)
|
)
|
||||||
|
|
||||||
# there should be a SRV lookup
|
# there should be a SRV lookup
|
||||||
self.mock_resolver.resolve_service.assert_called_once_with(
|
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||||
b"_matrix._tcp.target-server",
|
b"_matrix._tcp.target-server"
|
||||||
)
|
)
|
||||||
|
|
||||||
# now we should get a connection to the target server
|
# now we should get a connection to the target server
|
||||||
|
@ -444,8 +413,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
|
|
||||||
# make a test server, and wire up the client
|
# make a test server, and wire up the client
|
||||||
http_server = self._make_connection(
|
http_server = self._make_connection(
|
||||||
client_factory,
|
client_factory, expected_sni=b'target-server'
|
||||||
expected_sni=b'target-server',
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(http_server.requests), 1)
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
@ -453,8 +421,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(request.method, b'GET')
|
self.assertEqual(request.method, b'GET')
|
||||||
self.assertEqual(request.path, b'/foo/bar')
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
request.requestHeaders.getRawHeaders(b'host'),
|
request.requestHeaders.getRawHeaders(b'host'), [b'target-server']
|
||||||
[b'target-server'],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# finish the request
|
# finish the request
|
||||||
|
@ -490,8 +457,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(port, 443)
|
self.assertEqual(port, 443)
|
||||||
|
|
||||||
redirect_server = self._make_connection(
|
redirect_server = self._make_connection(
|
||||||
client_factory,
|
client_factory, expected_sni=b"testserv"
|
||||||
expected_sni=b"testserv",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# send a 302 redirect
|
# send a 302 redirect
|
||||||
|
@ -510,8 +476,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(port, 443)
|
self.assertEqual(port, 443)
|
||||||
|
|
||||||
well_known_server = self._make_connection(
|
well_known_server = self._make_connection(
|
||||||
client_factory,
|
client_factory, expected_sni=b"testserv"
|
||||||
expected_sni=b"testserv",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(well_known_server.requests), 1, "No request after 302")
|
self.assertEqual(len(well_known_server.requests), 1, "No request after 302")
|
||||||
|
@ -525,7 +490,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
|
|
||||||
# there should be a SRV lookup
|
# there should be a SRV lookup
|
||||||
self.mock_resolver.resolve_service.assert_called_once_with(
|
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||||
b"_matrix._tcp.target-server",
|
b"_matrix._tcp.target-server"
|
||||||
)
|
)
|
||||||
|
|
||||||
# now we should get a connection to the target server
|
# now we should get a connection to the target server
|
||||||
|
@ -536,8 +501,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
|
|
||||||
# make a test server, and wire up the client
|
# make a test server, and wire up the client
|
||||||
http_server = self._make_connection(
|
http_server = self._make_connection(
|
||||||
client_factory,
|
client_factory, expected_sni=b'target-server'
|
||||||
expected_sni=b'target-server',
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(http_server.requests), 1)
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
@ -545,8 +509,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(request.method, b'GET')
|
self.assertEqual(request.method, b'GET')
|
||||||
self.assertEqual(request.path, b'/foo/bar')
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
request.requestHeaders.getRawHeaders(b'host'),
|
request.requestHeaders.getRawHeaders(b'host'), [b'target-server']
|
||||||
[b'target-server'],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# finish the request
|
# finish the request
|
||||||
|
@ -585,12 +548,12 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(port, 443)
|
self.assertEqual(port, 443)
|
||||||
|
|
||||||
self._handle_well_known_connection(
|
self._handle_well_known_connection(
|
||||||
client_factory, expected_sni=b"testserv", content=b'NOT JSON',
|
client_factory, expected_sni=b"testserv", content=b'NOT JSON'
|
||||||
)
|
)
|
||||||
|
|
||||||
# now there should be a SRV lookup
|
# now there should be a SRV lookup
|
||||||
self.mock_resolver.resolve_service.assert_called_once_with(
|
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||||
b"_matrix._tcp.testserv",
|
b"_matrix._tcp.testserv"
|
||||||
)
|
)
|
||||||
|
|
||||||
# we should fall back to a direct connection
|
# we should fall back to a direct connection
|
||||||
|
@ -600,19 +563,13 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(port, 8448)
|
self.assertEqual(port, 8448)
|
||||||
|
|
||||||
# make a test server, and wire up the client
|
# make a test server, and wire up the client
|
||||||
http_server = self._make_connection(
|
http_server = self._make_connection(client_factory, expected_sni=b'testserv')
|
||||||
client_factory,
|
|
||||||
expected_sni=b'testserv',
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(len(http_server.requests), 1)
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
request = http_server.requests[0]
|
request = http_server.requests[0]
|
||||||
self.assertEqual(request.method, b'GET')
|
self.assertEqual(request.method, b'GET')
|
||||||
self.assertEqual(request.path, b'/foo/bar')
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
self.assertEqual(
|
self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])
|
||||||
request.requestHeaders.getRawHeaders(b'host'),
|
|
||||||
[b'testserv'],
|
|
||||||
)
|
|
||||||
|
|
||||||
# finish the request
|
# finish the request
|
||||||
request.finish()
|
request.finish()
|
||||||
|
@ -635,7 +592,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
|
|
||||||
# the request for a .well-known will have failed with a DNS lookup error.
|
# the request for a .well-known will have failed with a DNS lookup error.
|
||||||
self.mock_resolver.resolve_service.assert_called_once_with(
|
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||||
b"_matrix._tcp.testserv",
|
b"_matrix._tcp.testserv"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make sure treq is trying to connect
|
# Make sure treq is trying to connect
|
||||||
|
@ -646,19 +603,13 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(port, 8443)
|
self.assertEqual(port, 8443)
|
||||||
|
|
||||||
# make a test server, and wire up the client
|
# make a test server, and wire up the client
|
||||||
http_server = self._make_connection(
|
http_server = self._make_connection(client_factory, expected_sni=b'testserv')
|
||||||
client_factory,
|
|
||||||
expected_sni=b'testserv',
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(len(http_server.requests), 1)
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
request = http_server.requests[0]
|
request = http_server.requests[0]
|
||||||
self.assertEqual(request.method, b'GET')
|
self.assertEqual(request.method, b'GET')
|
||||||
self.assertEqual(request.path, b'/foo/bar')
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
self.assertEqual(
|
self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv'])
|
||||||
request.requestHeaders.getRawHeaders(b'host'),
|
|
||||||
[b'testserv'],
|
|
||||||
)
|
|
||||||
|
|
||||||
# finish the request
|
# finish the request
|
||||||
request.finish()
|
request.finish()
|
||||||
|
@ -685,17 +636,18 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(port, 443)
|
self.assertEqual(port, 443)
|
||||||
|
|
||||||
self.mock_resolver.resolve_service.side_effect = lambda _: [
|
self.mock_resolver.resolve_service.side_effect = lambda _: [
|
||||||
Server(host=b"srvtarget", port=8443),
|
Server(host=b"srvtarget", port=8443)
|
||||||
]
|
]
|
||||||
|
|
||||||
self._handle_well_known_connection(
|
self._handle_well_known_connection(
|
||||||
client_factory, expected_sni=b"testserv",
|
client_factory,
|
||||||
|
expected_sni=b"testserv",
|
||||||
content=b'{ "m.server": "target-server" }',
|
content=b'{ "m.server": "target-server" }',
|
||||||
)
|
)
|
||||||
|
|
||||||
# there should be a SRV lookup
|
# there should be a SRV lookup
|
||||||
self.mock_resolver.resolve_service.assert_called_once_with(
|
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||||
b"_matrix._tcp.target-server",
|
b"_matrix._tcp.target-server"
|
||||||
)
|
)
|
||||||
|
|
||||||
# now we should get a connection to the target of the SRV record
|
# now we should get a connection to the target of the SRV record
|
||||||
|
@ -706,8 +658,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
|
|
||||||
# make a test server, and wire up the client
|
# make a test server, and wire up the client
|
||||||
http_server = self._make_connection(
|
http_server = self._make_connection(
|
||||||
client_factory,
|
client_factory, expected_sni=b'target-server'
|
||||||
expected_sni=b'target-server',
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(http_server.requests), 1)
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
@ -715,8 +666,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(request.method, b'GET')
|
self.assertEqual(request.method, b'GET')
|
||||||
self.assertEqual(request.path, b'/foo/bar')
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
request.requestHeaders.getRawHeaders(b'host'),
|
request.requestHeaders.getRawHeaders(b'host'), [b'target-server']
|
||||||
[b'target-server'],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# finish the request
|
# finish the request
|
||||||
|
@ -757,7 +707,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
|
|
||||||
# now there should have been a SRV lookup
|
# now there should have been a SRV lookup
|
||||||
self.mock_resolver.resolve_service.assert_called_once_with(
|
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||||
b"_matrix._tcp.xn--bcher-kva.com",
|
b"_matrix._tcp.xn--bcher-kva.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
# We should fall back to port 8448
|
# We should fall back to port 8448
|
||||||
|
@ -769,8 +719,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
|
|
||||||
# make a test server, and wire up the client
|
# make a test server, and wire up the client
|
||||||
http_server = self._make_connection(
|
http_server = self._make_connection(
|
||||||
client_factory,
|
client_factory, expected_sni=b'xn--bcher-kva.com'
|
||||||
expected_sni=b'xn--bcher-kva.com',
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(http_server.requests), 1)
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
@ -778,8 +727,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(request.method, b'GET')
|
self.assertEqual(request.method, b'GET')
|
||||||
self.assertEqual(request.path, b'/foo/bar')
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
request.requestHeaders.getRawHeaders(b'host'),
|
request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com']
|
||||||
[b'xn--bcher-kva.com'],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# finish the request
|
# finish the request
|
||||||
|
@ -801,7 +749,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertNoResult(test_d)
|
self.assertNoResult(test_d)
|
||||||
|
|
||||||
self.mock_resolver.resolve_service.assert_called_once_with(
|
self.mock_resolver.resolve_service.assert_called_once_with(
|
||||||
b"_matrix._tcp.xn--bcher-kva.com",
|
b"_matrix._tcp.xn--bcher-kva.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make sure treq is trying to connect
|
# Make sure treq is trying to connect
|
||||||
|
@ -813,8 +761,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
|
|
||||||
# make a test server, and wire up the client
|
# make a test server, and wire up the client
|
||||||
http_server = self._make_connection(
|
http_server = self._make_connection(
|
||||||
client_factory,
|
client_factory, expected_sni=b'xn--bcher-kva.com'
|
||||||
expected_sni=b'xn--bcher-kva.com',
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(http_server.requests), 1)
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
@ -822,8 +769,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||||
self.assertEqual(request.method, b'GET')
|
self.assertEqual(request.method, b'GET')
|
||||||
self.assertEqual(request.path, b'/foo/bar')
|
self.assertEqual(request.path, b'/foo/bar')
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
request.requestHeaders.getRawHeaders(b'host'),
|
request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com']
|
||||||
[b'xn--bcher-kva.com'],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# finish the request
|
# finish the request
|
||||||
|
@ -897,67 +843,70 @@ class TestCachePeriodFromHeaders(TestCase):
|
||||||
# uppercase
|
# uppercase
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
_cache_period_from_headers(
|
_cache_period_from_headers(
|
||||||
Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}),
|
Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']})
|
||||||
), 100,
|
),
|
||||||
|
100,
|
||||||
)
|
)
|
||||||
|
|
||||||
# missing value
|
# missing value
|
||||||
self.assertIsNone(_cache_period_from_headers(
|
self.assertIsNone(
|
||||||
Headers({b'Cache-Control': [b'max-age=, bar']}),
|
_cache_period_from_headers(Headers({b'Cache-Control': [b'max-age=, bar']}))
|
||||||
))
|
)
|
||||||
|
|
||||||
# hackernews: bogus due to semicolon
|
# hackernews: bogus due to semicolon
|
||||||
self.assertIsNone(_cache_period_from_headers(
|
self.assertIsNone(
|
||||||
Headers({b'Cache-Control': [b'private; max-age=0']}),
|
_cache_period_from_headers(
|
||||||
))
|
Headers({b'Cache-Control': [b'private; max-age=0']})
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# github
|
# github
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
_cache_period_from_headers(
|
_cache_period_from_headers(
|
||||||
Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}),
|
Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']})
|
||||||
), 0,
|
),
|
||||||
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# google
|
# google
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
_cache_period_from_headers(
|
_cache_period_from_headers(
|
||||||
Headers({b'cache-control': [b'private, max-age=0']}),
|
Headers({b'cache-control': [b'private, max-age=0']})
|
||||||
), 0,
|
),
|
||||||
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_expires(self):
|
def test_expires(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
_cache_period_from_headers(
|
_cache_period_from_headers(
|
||||||
Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}),
|
Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}),
|
||||||
time_now=lambda: 1548833700
|
time_now=lambda: 1548833700,
|
||||||
), 33,
|
),
|
||||||
|
33,
|
||||||
)
|
)
|
||||||
|
|
||||||
# cache-control overrides expires
|
# cache-control overrides expires
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
_cache_period_from_headers(
|
_cache_period_from_headers(
|
||||||
Headers({
|
Headers(
|
||||||
|
{
|
||||||
b'cache-control': [b'max-age=10'],
|
b'cache-control': [b'max-age=10'],
|
||||||
b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']
|
b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'],
|
||||||
}),
|
}
|
||||||
time_now=lambda: 1548833700
|
),
|
||||||
), 10,
|
time_now=lambda: 1548833700,
|
||||||
|
),
|
||||||
|
10,
|
||||||
)
|
)
|
||||||
|
|
||||||
# invalid expires means immediate expiry
|
# invalid expires means immediate expiry
|
||||||
self.assertEqual(
|
self.assertEqual(_cache_period_from_headers(Headers({b'Expires': [b'0']})), 0)
|
||||||
_cache_period_from_headers(
|
|
||||||
Headers({b'Expires': [b'0']}),
|
|
||||||
), 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_logcontext(context):
|
def _check_logcontext(context):
|
||||||
current = LoggingContext.current_context()
|
current = LoggingContext.current_context()
|
||||||
if current is not context:
|
if current is not context:
|
||||||
raise AssertionError(
|
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
|
||||||
"Expected logcontext %s but was %s" % (context, current),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_test_server():
|
def _build_test_server():
|
||||||
|
@ -973,7 +922,7 @@ def _build_test_server():
|
||||||
server_factory.log = _log_request
|
server_factory.log = _log_request
|
||||||
|
|
||||||
server_tls_factory = TLSMemoryBIOFactory(
|
server_tls_factory = TLSMemoryBIOFactory(
|
||||||
ServerTLSContext(), isClient=False, wrappedFactory=server_factory,
|
ServerTLSContext(), isClient=False, wrappedFactory=server_factory
|
||||||
)
|
)
|
||||||
|
|
||||||
return server_tls_factory.buildProtocol(None)
|
return server_tls_factory.buildProtocol(None)
|
||||||
|
@ -987,6 +936,7 @@ def _log_request(request):
|
||||||
@implementer(IPolicyForHTTPS)
|
@implementer(IPolicyForHTTPS)
|
||||||
class TrustingTLSPolicyForHTTPS(object):
|
class TrustingTLSPolicyForHTTPS(object):
|
||||||
"""An IPolicyForHTTPS which doesn't do any certificate verification"""
|
"""An IPolicyForHTTPS which doesn't do any certificate verification"""
|
||||||
|
|
||||||
def creatorForNetloc(self, hostname, port):
|
def creatorForNetloc(self, hostname, port):
|
||||||
certificateOptions = OpenSSLCertificateOptions()
|
certificateOptions = OpenSSLCertificateOptions()
|
||||||
return ClientTLSOptions(hostname, certificateOptions.getContext())
|
return ClientTLSOptions(hostname, certificateOptions.getContext())
|
||||||
|
|
|
@ -68,9 +68,7 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
|
|
||||||
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
||||||
|
|
||||||
result_deferred.callback(
|
result_deferred.callback(([answer_srv], None, None))
|
||||||
([answer_srv], None, None)
|
|
||||||
)
|
|
||||||
|
|
||||||
servers = self.successResultOf(test_d)
|
servers = self.successResultOf(test_d)
|
||||||
|
|
||||||
|
@ -112,7 +110,7 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
|
|
||||||
cache = {service_name: [entry]}
|
cache = {service_name: [entry]}
|
||||||
resolver = SrvResolver(
|
resolver = SrvResolver(
|
||||||
dns_client=dns_client_mock, cache=cache, get_time=clock.time,
|
dns_client=dns_client_mock, cache=cache, get_time=clock.time
|
||||||
)
|
)
|
||||||
|
|
||||||
servers = yield resolver.resolve_service(service_name)
|
servers = yield resolver.resolve_service(service_name)
|
||||||
|
@ -168,11 +166,13 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
self.assertNoResult(resolve_d)
|
self.assertNoResult(resolve_d)
|
||||||
|
|
||||||
# returning a single "." should make the lookup fail with a ConenctError
|
# returning a single "." should make the lookup fail with a ConenctError
|
||||||
lookup_deferred.callback((
|
lookup_deferred.callback(
|
||||||
|
(
|
||||||
[dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))],
|
[dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))],
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.failureResultOf(resolve_d, ConnectError)
|
self.failureResultOf(resolve_d, ConnectError)
|
||||||
|
|
||||||
|
@ -191,14 +191,16 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
resolve_d = resolver.resolve_service(service_name)
|
resolve_d = resolver.resolve_service(service_name)
|
||||||
self.assertNoResult(resolve_d)
|
self.assertNoResult(resolve_d)
|
||||||
|
|
||||||
lookup_deferred.callback((
|
lookup_deferred.callback(
|
||||||
|
(
|
||||||
[
|
[
|
||||||
dns.RRHeader(type=dns.A, payload=dns.Record_A()),
|
dns.RRHeader(type=dns.A, payload=dns.Record_A()),
|
||||||
dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")),
|
dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")),
|
||||||
],
|
],
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
servers = self.successResultOf(resolve_d)
|
servers = self.successResultOf(resolve_d)
|
||||||
|
|
||||||
|
|
|
@ -36,9 +36,7 @@ from tests.unittest import HomeserverTestCase
|
||||||
def check_logcontext(context):
|
def check_logcontext(context):
|
||||||
current = LoggingContext.current_context()
|
current = LoggingContext.current_context()
|
||||||
if current is not context:
|
if current is not context:
|
||||||
raise AssertionError(
|
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
|
||||||
"Expected logcontext %s but was %s" % (context, current),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FederationClientTests(HomeserverTestCase):
|
class FederationClientTests(HomeserverTestCase):
|
||||||
|
@ -54,6 +52,7 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
happy-path test of a GET request
|
happy-path test of a GET request
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_request():
|
def do_request():
|
||||||
with LoggingContext("one") as context:
|
with LoggingContext("one") as context:
|
||||||
|
@ -175,8 +174,7 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
|
|
||||||
self.assertIsInstance(f.value, RequestSendFailed)
|
self.assertIsInstance(f.value, RequestSendFailed)
|
||||||
self.assertIsInstance(
|
self.assertIsInstance(
|
||||||
f.value.inner_exception,
|
f.value.inner_exception, (ConnectingCancelledError, TimeoutError)
|
||||||
(ConnectingCancelledError, TimeoutError),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_client_connect_no_response(self):
|
def test_client_connect_no_response(self):
|
||||||
|
@ -216,9 +214,7 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
Once the client gets the headers, _request returns successfully.
|
Once the client gets the headers, _request returns successfully.
|
||||||
"""
|
"""
|
||||||
request = MatrixFederationRequest(
|
request = MatrixFederationRequest(
|
||||||
method="GET",
|
method="GET", destination="testserv:8008", path="foo/bar"
|
||||||
destination="testserv:8008",
|
|
||||||
path="foo/bar",
|
|
||||||
)
|
)
|
||||||
d = self.cl._send_request(request, timeout=10000)
|
d = self.cl._send_request(request, timeout=10000)
|
||||||
|
|
||||||
|
@ -258,8 +254,10 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
|
|
||||||
# Send it the HTTP response
|
# Send it the HTTP response
|
||||||
client.dataReceived(
|
client.dataReceived(
|
||||||
(b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
|
(
|
||||||
b"Server: Fake\r\n\r\n")
|
b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
|
||||||
|
b"Server: Fake\r\n\r\n"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Push by enough to time it out
|
# Push by enough to time it out
|
||||||
|
@ -274,9 +272,7 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
requiring a trailing slash. We need to retry the request with a
|
requiring a trailing slash. We need to retry the request with a
|
||||||
trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622.
|
trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622.
|
||||||
"""
|
"""
|
||||||
d = self.cl.get_json(
|
d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
|
||||||
"testserv:8008", "foo/bar", try_trailing_slash_on_400=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send the request
|
# Send the request
|
||||||
self.pump()
|
self.pump()
|
||||||
|
@ -329,9 +325,7 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
|
|
||||||
See test_client_requires_trailing_slashes() for context.
|
See test_client_requires_trailing_slashes() for context.
|
||||||
"""
|
"""
|
||||||
d = self.cl.get_json(
|
d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
|
||||||
"testserv:8008", "foo/bar", try_trailing_slash_on_400=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send the request
|
# Send the request
|
||||||
self.pump()
|
self.pump()
|
||||||
|
@ -368,10 +362,7 @@ class FederationClientTests(HomeserverTestCase):
|
||||||
self.failureResultOf(d)
|
self.failureResultOf(d)
|
||||||
|
|
||||||
def test_client_sends_body(self):
|
def test_client_sends_body(self):
|
||||||
self.cl.post_json(
|
self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"})
|
||||||
"testserv:8008", "foo/bar", timeout=10000,
|
|
||||||
data={"a": "b"}
|
|
||||||
)
|
|
||||||
|
|
||||||
self.pump()
|
self.pump()
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,9 @@ def do_patch():
|
||||||
except Exception:
|
except Exception:
|
||||||
if LoggingContext.current_context() != start_context:
|
if LoggingContext.current_context() != start_context:
|
||||||
err = "%s changed context from %s to %s on exception" % (
|
err = "%s changed context from %s to %s on exception" % (
|
||||||
f, start_context, LoggingContext.current_context()
|
f,
|
||||||
|
start_context,
|
||||||
|
LoggingContext.current_context(),
|
||||||
)
|
)
|
||||||
print(err, file=sys.stderr)
|
print(err, file=sys.stderr)
|
||||||
raise Exception(err)
|
raise Exception(err)
|
||||||
|
@ -54,7 +56,9 @@ def do_patch():
|
||||||
if not isinstance(res, Deferred) or res.called:
|
if not isinstance(res, Deferred) or res.called:
|
||||||
if LoggingContext.current_context() != start_context:
|
if LoggingContext.current_context() != start_context:
|
||||||
err = "%s changed context from %s to %s" % (
|
err = "%s changed context from %s to %s" % (
|
||||||
f, start_context, LoggingContext.current_context()
|
f,
|
||||||
|
start_context,
|
||||||
|
LoggingContext.current_context(),
|
||||||
)
|
)
|
||||||
# print the error to stderr because otherwise all we
|
# print the error to stderr because otherwise all we
|
||||||
# see in travis-ci is the 500 error
|
# see in travis-ci is the 500 error
|
||||||
|
@ -66,9 +70,7 @@ def do_patch():
|
||||||
err = (
|
err = (
|
||||||
"%s returned incomplete deferred in non-sentinel context "
|
"%s returned incomplete deferred in non-sentinel context "
|
||||||
"%s (start was %s)"
|
"%s (start was %s)"
|
||||||
) % (
|
) % (f, LoggingContext.current_context(), start_context)
|
||||||
f, LoggingContext.current_context(), start_context,
|
|
||||||
)
|
|
||||||
print(err, file=sys.stderr)
|
print(err, file=sys.stderr)
|
||||||
raise Exception(err)
|
raise Exception(err)
|
||||||
|
|
||||||
|
@ -76,7 +78,9 @@ def do_patch():
|
||||||
if LoggingContext.current_context() != start_context:
|
if LoggingContext.current_context() != start_context:
|
||||||
err = "%s completion of %s changed context from %s to %s" % (
|
err = "%s completion of %s changed context from %s to %s" % (
|
||||||
"Failure" if isinstance(r, Failure) else "Success",
|
"Failure" if isinstance(r, Failure) else "Success",
|
||||||
f, start_context, LoggingContext.current_context(),
|
f,
|
||||||
|
start_context,
|
||||||
|
LoggingContext.current_context(),
|
||||||
)
|
)
|
||||||
print(err, file=sys.stderr)
|
print(err, file=sys.stderr)
|
||||||
raise Exception(err)
|
raise Exception(err)
|
||||||
|
|
|
@ -74,21 +74,18 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
master_result,
|
master_result,
|
||||||
expected_result,
|
expected_result,
|
||||||
"Expected master result to be %r but was %r" % (
|
"Expected master result to be %r but was %r"
|
||||||
expected_result, master_result
|
% (expected_result, master_result),
|
||||||
),
|
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
slaved_result,
|
slaved_result,
|
||||||
expected_result,
|
expected_result,
|
||||||
"Expected slave result to be %r but was %r" % (
|
"Expected slave result to be %r but was %r"
|
||||||
expected_result, slaved_result
|
% (expected_result, slaved_result),
|
||||||
),
|
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
master_result,
|
master_result,
|
||||||
slaved_result,
|
slaved_result,
|
||||||
"Slave result %r does not match master result %r" % (
|
"Slave result %r does not match master result %r"
|
||||||
slaved_result, master_result
|
% (slaved_result, master_result),
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -234,10 +234,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
|
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
|
||||||
)
|
)
|
||||||
msg, msgctx = self.build_event()
|
msg, msgctx = self.build_event()
|
||||||
self.get_success(self.master_store.persist_events([
|
self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)]))
|
||||||
(j2, j2ctx),
|
|
||||||
(msg, msgctx),
|
|
||||||
]))
|
|
||||||
self.replicate()
|
self.replicate()
|
||||||
|
|
||||||
event_source = RoomEventSource(self.hs)
|
event_source = RoomEventSource(self.hs)
|
||||||
|
@ -257,15 +254,13 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
#
|
#
|
||||||
# First, we get a list of the rooms we are joined to
|
# First, we get a list of the rooms we are joined to
|
||||||
joined_rooms = self.get_success(
|
joined_rooms = self.get_success(
|
||||||
self.slaved_store.get_rooms_for_user_with_stream_ordering(
|
self.slaved_store.get_rooms_for_user_with_stream_ordering(USER_ID_2)
|
||||||
USER_ID_2,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Then, we get a list of the events since the last sync
|
# Then, we get a list of the events since the last sync
|
||||||
membership_changes = self.get_success(
|
membership_changes = self.get_success(
|
||||||
self.slaved_store.get_membership_changes_for_user(
|
self.slaved_store.get_membership_changes_for_user(
|
||||||
USER_ID_2, prev_token, current_token,
|
USER_ID_2, prev_token, current_token
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -298,9 +293,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
self.master_store.persist_events([(event, context)], backfilled=True)
|
self.master_store.persist_events([(event, context)], backfilled=True)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.get_success(
|
self.get_success(self.master_store.persist_event(event, context))
|
||||||
self.master_store.persist_event(event, context)
|
|
||||||
)
|
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
@ -359,9 +352,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
state_handler = self.hs.get_state_handler()
|
state_handler = self.hs.get_state_handler()
|
||||||
context = self.get_success(state_handler.compute_event_context(
|
context = self.get_success(state_handler.compute_event_context(event))
|
||||||
event
|
|
||||||
))
|
|
||||||
|
|
||||||
self.master_store.add_push_actions_to_staging(
|
self.master_store.add_push_actions_to_staging(
|
||||||
event.event_id, {user_id: actions for user_id, actions in push_actions}
|
event.event_id, {user_id: actions for user_id, actions in push_actions}
|
||||||
|
|
|
@ -22,6 +22,7 @@ from tests.server import FakeTransport
|
||||||
|
|
||||||
class BaseStreamTestCase(unittest.HomeserverTestCase):
|
class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
"""Base class for tests of the replication streams"""
|
"""Base class for tests of the replication streams"""
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
# build a replication server
|
# build a replication server
|
||||||
server_factory = ReplicationStreamProtocolFactory(self.hs)
|
server_factory = ReplicationStreamProtocolFactory(self.hs)
|
||||||
|
@ -52,6 +53,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
class TestReplicationClientHandler(object):
|
class TestReplicationClientHandler(object):
|
||||||
"""Drop-in for ReplicationClientHandler which just collects RDATA rows"""
|
"""Drop-in for ReplicationClientHandler which just collects RDATA rows"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.received_rdata_rows = []
|
self.received_rdata_rows = []
|
||||||
|
|
||||||
|
@ -69,6 +71,4 @@ class TestReplicationClientHandler(object):
|
||||||
|
|
||||||
def on_rdata(self, stream_name, token, rows):
|
def on_rdata(self, stream_name, token, rows):
|
||||||
for r in rows:
|
for r in rows:
|
||||||
self.received_rdata_rows.append(
|
self.received_rdata_rows.append((stream_name, token, r))
|
||||||
(stream_name, token, r)
|
|
||||||
)
|
|
||||||
|
|
|
@ -41,10 +41,10 @@ class VersionTestCase(unittest.HomeserverTestCase):
|
||||||
request, channel = self.make_request("GET", self.url, shorthand=False)
|
request, channel = self.make_request("GET", self.url, shorthand=False)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]),
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
msg=channel.result["body"])
|
self.assertEqual(
|
||||||
self.assertEqual({'server_version', 'python_version'},
|
{'server_version', 'python_version'}, set(channel.json_body.keys())
|
||||||
set(channel.json_body.keys()))
|
)
|
||||||
|
|
||||||
|
|
||||||
class UserRegisterTestCase(unittest.HomeserverTestCase):
|
class UserRegisterTestCase(unittest.HomeserverTestCase):
|
||||||
|
@ -200,9 +200,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
|
||||||
nonce = channel.json_body["nonce"]
|
nonce = channel.json_body["nonce"]
|
||||||
|
|
||||||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
||||||
want_mac.update(
|
want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
|
||||||
nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin"
|
|
||||||
)
|
|
||||||
want_mac = want_mac.hexdigest()
|
want_mac = want_mac.hexdigest()
|
||||||
|
|
||||||
body = json.dumps(
|
body = json.dumps(
|
||||||
|
@ -330,11 +328,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
|
||||||
#
|
#
|
||||||
|
|
||||||
# Invalid user_type
|
# Invalid user_type
|
||||||
body = json.dumps({
|
body = json.dumps(
|
||||||
|
{
|
||||||
"nonce": nonce(),
|
"nonce": nonce(),
|
||||||
"username": "a",
|
"username": "a",
|
||||||
"password": "1234",
|
"password": "1234",
|
||||||
"user_type": "invalid"}
|
"user_type": "invalid",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
@ -357,9 +357,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
|
||||||
hs.config.user_consent_version = "1"
|
hs.config.user_consent_version = "1"
|
||||||
|
|
||||||
consent_uri_builder = Mock()
|
consent_uri_builder = Mock()
|
||||||
consent_uri_builder.build_user_consent_uri.return_value = (
|
consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
|
||||||
"http://example.com"
|
|
||||||
)
|
|
||||||
self.event_creation_handler._consent_uri_builder = consent_uri_builder
|
self.event_creation_handler._consent_uri_builder = consent_uri_builder
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
@ -371,9 +369,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
|
||||||
self.other_user_token = self.login("user", "pass")
|
self.other_user_token = self.login("user", "pass")
|
||||||
|
|
||||||
# Mark the admin user as having consented
|
# Mark the admin user as having consented
|
||||||
self.get_success(
|
self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
|
||||||
self.store.user_set_consent_version(self.admin_user, "1"),
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_shutdown_room_consent(self):
|
def test_shutdown_room_consent(self):
|
||||||
"""Test that we can shutdown rooms with local users who have not
|
"""Test that we can shutdown rooms with local users who have not
|
||||||
|
@ -385,9 +381,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
|
||||||
room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
|
room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
|
||||||
|
|
||||||
# Assert one user in room
|
# Assert one user in room
|
||||||
users_in_room = self.get_success(
|
users_in_room = self.get_success(self.store.get_users_in_room(room_id))
|
||||||
self.store.get_users_in_room(room_id),
|
|
||||||
)
|
|
||||||
self.assertEqual([self.other_user], users_in_room)
|
self.assertEqual([self.other_user], users_in_room)
|
||||||
|
|
||||||
# Enable require consent to send events
|
# Enable require consent to send events
|
||||||
|
@ -395,8 +389,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Assert that the user is getting consent error
|
# Assert that the user is getting consent error
|
||||||
self.helper.send(
|
self.helper.send(
|
||||||
room_id,
|
room_id, body="foo", tok=self.other_user_token, expect_code=403
|
||||||
body="foo", tok=self.other_user_token, expect_code=403,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test that the admin can still send shutdown
|
# Test that the admin can still send shutdown
|
||||||
|
@ -412,9 +405,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
# Assert there is now no longer anyone in the room
|
# Assert there is now no longer anyone in the room
|
||||||
users_in_room = self.get_success(
|
users_in_room = self.get_success(self.store.get_users_in_room(room_id))
|
||||||
self.store.get_users_in_room(room_id),
|
|
||||||
)
|
|
||||||
self.assertEqual([], users_in_room)
|
self.assertEqual([], users_in_room)
|
||||||
|
|
||||||
@unittest.DEBUG
|
@unittest.DEBUG
|
||||||
|
@ -459,24 +450,20 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
url = "rooms/%s/initialSync" % (room_id,)
|
url = "rooms/%s/initialSync" % (room_id,)
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"GET",
|
"GET", url.encode('ascii'), access_token=self.admin_user_tok
|
||||||
url.encode('ascii'),
|
|
||||||
access_token=self.admin_user_tok,
|
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
expect_code, int(channel.result["code"]), msg=channel.result["body"],
|
expect_code, int(channel.result["code"]), msg=channel.result["body"]
|
||||||
)
|
)
|
||||||
|
|
||||||
url = "events?timeout=0&room_id=" + room_id
|
url = "events?timeout=0&room_id=" + room_id
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"GET",
|
"GET", url.encode('ascii'), access_token=self.admin_user_tok
|
||||||
url.encode('ascii'),
|
|
||||||
access_token=self.admin_user_tok,
|
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
expect_code, int(channel.result["code"]), msg=channel.result["body"],
|
expect_code, int(channel.result["code"]), msg=channel.result["body"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -502,15 +489,11 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
|
||||||
"POST",
|
"POST",
|
||||||
"/create_group".encode('ascii'),
|
"/create_group".encode('ascii'),
|
||||||
access_token=self.admin_user_tok,
|
access_token=self.admin_user_tok,
|
||||||
content={
|
content={"localpart": "test"},
|
||||||
"localpart": "test",
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
200, int(channel.result["code"]), msg=channel.result["body"],
|
|
||||||
)
|
|
||||||
|
|
||||||
group_id = channel.json_body["group_id"]
|
group_id = channel.json_body["group_id"]
|
||||||
|
|
||||||
|
@ -520,27 +503,17 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
|
url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"PUT",
|
"PUT", url.encode('ascii'), access_token=self.admin_user_tok, content={}
|
||||||
url.encode('ascii'),
|
|
||||||
access_token=self.admin_user_tok,
|
|
||||||
content={}
|
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
200, int(channel.result["code"]), msg=channel.result["body"],
|
|
||||||
)
|
|
||||||
|
|
||||||
url = "/groups/%s/self/accept_invite" % (group_id,)
|
url = "/groups/%s/self/accept_invite" % (group_id,)
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"PUT",
|
"PUT", url.encode('ascii'), access_token=self.other_user_token, content={}
|
||||||
url.encode('ascii'),
|
|
||||||
access_token=self.other_user_token,
|
|
||||||
content={}
|
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
200, int(channel.result["code"]), msg=channel.result["body"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check other user knows they're in the group
|
# Check other user knows they're in the group
|
||||||
self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
|
self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
|
||||||
|
@ -552,15 +525,11 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
|
||||||
"POST",
|
"POST",
|
||||||
url.encode('ascii'),
|
url.encode('ascii'),
|
||||||
access_token=self.admin_user_tok,
|
access_token=self.admin_user_tok,
|
||||||
content={
|
content={"localpart": "test"},
|
||||||
"localpart": "test",
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
200, int(channel.result["code"]), msg=channel.result["body"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check group returns 404
|
# Check group returns 404
|
||||||
self._check_group(group_id, expect_code=404)
|
self._check_group(group_id, expect_code=404)
|
||||||
|
@ -576,28 +545,22 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
url = "/groups/%s/profile" % (group_id,)
|
url = "/groups/%s/profile" % (group_id,)
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"GET",
|
"GET", url.encode('ascii'), access_token=self.admin_user_tok
|
||||||
url.encode('ascii'),
|
|
||||||
access_token=self.admin_user_tok,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
expect_code, int(channel.result["code"]), msg=channel.result["body"],
|
expect_code, int(channel.result["code"]), msg=channel.result["body"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_groups_user_is_in(self, access_token):
|
def _get_groups_user_is_in(self, access_token):
|
||||||
"""Returns the list of groups the user is in (given their access token)
|
"""Returns the list of groups the user is in (given their access token)
|
||||||
"""
|
"""
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"GET",
|
"GET", "/joined_groups".encode('ascii'), access_token=access_token
|
||||||
"/joined_groups".encode('ascii'),
|
|
||||||
access_token=access_token,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
200, int(channel.result["code"]), msg=channel.result["body"],
|
|
||||||
)
|
|
||||||
|
|
||||||
return channel.json_body["groups"]
|
return channel.json_body["groups"]
|
||||||
|
|
|
@ -44,7 +44,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
|
||||||
tok = self.login("kermit", "monkey")
|
tok = self.login("kermit", "monkey")
|
||||||
|
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
b"POST", "/createRoom", b"{}", access_token=tok,
|
b"POST", "/createRoom", b"{}", access_token=tok
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
@ -56,11 +56,9 @@ class IdentityTestCase(unittest.HomeserverTestCase):
|
||||||
"address": "test@example.com",
|
"address": "test@example.com",
|
||||||
}
|
}
|
||||||
request_data = json.dumps(params)
|
request_data = json.dumps(params)
|
||||||
request_url = (
|
request_url = ("/rooms/%s/invite" % (room_id)).encode('ascii')
|
||||||
"/rooms/%s/invite" % (room_id)
|
|
||||||
).encode('ascii')
|
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
b"POST", request_url, request_data, access_token=tok,
|
b"POST", request_url, request_data, access_token=tok
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||||
|
|
|
@ -45,7 +45,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
self.room_owner_tok = self.login("room_owner", "test")
|
self.room_owner_tok = self.login("room_owner", "test")
|
||||||
|
|
||||||
self.room_id = self.helper.create_room_as(
|
self.room_id = self.helper.create_room_as(
|
||||||
self.room_owner, tok=self.room_owner_tok,
|
self.room_owner, tok=self.room_owner_tok
|
||||||
)
|
)
|
||||||
|
|
||||||
self.user = self.register_user("user", "test")
|
self.user = self.register_user("user", "test")
|
||||||
|
@ -80,12 +80,10 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# We use deliberately a localpart under the length threshold so
|
# We use deliberately a localpart under the length threshold so
|
||||||
# that we can make sure that the check is done on the whole alias.
|
# that we can make sure that the check is done on the whole alias.
|
||||||
data = {
|
data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
|
||||||
"room_alias_name": random_string(256 - len(self.hs.hostname)),
|
|
||||||
}
|
|
||||||
request_data = json.dumps(data)
|
request_data = json.dumps(data)
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"POST", url, request_data, access_token=self.user_tok,
|
"POST", url, request_data, access_token=self.user_tok
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(channel.code, 400, channel.result)
|
self.assertEqual(channel.code, 400, channel.result)
|
||||||
|
@ -96,51 +94,42 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
# Check with an alias of allowed length. There should already be
|
# Check with an alias of allowed length. There should already be
|
||||||
# a test that ensures it works in test_register.py, but let's be
|
# a test that ensures it works in test_register.py, but let's be
|
||||||
# as cautious as possible here.
|
# as cautious as possible here.
|
||||||
data = {
|
data = {"room_alias_name": random_string(5)}
|
||||||
"room_alias_name": random_string(5),
|
|
||||||
}
|
|
||||||
request_data = json.dumps(data)
|
request_data = json.dumps(data)
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"POST", url, request_data, access_token=self.user_tok,
|
"POST", url, request_data, access_token=self.user_tok
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(channel.code, 200, channel.result)
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
|
||||||
def set_alias_via_state_event(self, expected_code, alias_length=5):
|
def set_alias_via_state_event(self, expected_code, alias_length=5):
|
||||||
url = ("/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s"
|
url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
|
||||||
% (self.room_id, self.hs.hostname))
|
self.room_id,
|
||||||
|
self.hs.hostname,
|
||||||
|
)
|
||||||
|
|
||||||
data = {
|
data = {"aliases": [self.random_alias(alias_length)]}
|
||||||
"aliases": [
|
|
||||||
self.random_alias(alias_length),
|
|
||||||
],
|
|
||||||
}
|
|
||||||
request_data = json.dumps(data)
|
request_data = json.dumps(data)
|
||||||
|
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"PUT", url, request_data, access_token=self.user_tok,
|
"PUT", url, request_data, access_token=self.user_tok
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(channel.code, expected_code, channel.result)
|
self.assertEqual(channel.code, expected_code, channel.result)
|
||||||
|
|
||||||
def set_alias_via_directory(self, expected_code, alias_length=5):
|
def set_alias_via_directory(self, expected_code, alias_length=5):
|
||||||
url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length)
|
url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length)
|
||||||
data = {
|
data = {"room_id": self.room_id}
|
||||||
"room_id": self.room_id,
|
|
||||||
}
|
|
||||||
request_data = json.dumps(data)
|
request_data = json.dumps(data)
|
||||||
|
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"PUT", url, request_data, access_token=self.user_tok,
|
"PUT", url, request_data, access_token=self.user_tok
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(channel.code, expected_code, channel.result)
|
self.assertEqual(channel.code, expected_code, channel.result)
|
||||||
|
|
||||||
def random_alias(self, length):
|
def random_alias(self, length):
|
||||||
return RoomAlias(
|
return RoomAlias(random_string(length), self.hs.hostname).to_string()
|
||||||
random_string(length),
|
|
||||||
self.hs.hostname,
|
|
||||||
).to_string()
|
|
||||||
|
|
||||||
def ensure_user_left_room(self):
|
def ensure_user_left_room(self):
|
||||||
self.ensure_membership("leave")
|
self.ensure_membership("leave")
|
||||||
|
@ -151,17 +140,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
def ensure_membership(self, membership):
|
def ensure_membership(self, membership):
|
||||||
try:
|
try:
|
||||||
if membership == "leave":
|
if membership == "leave":
|
||||||
self.helper.leave(
|
self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok)
|
||||||
room=self.room_id,
|
|
||||||
user=self.user,
|
|
||||||
tok=self.user_tok,
|
|
||||||
)
|
|
||||||
if membership == "join":
|
if membership == "join":
|
||||||
self.helper.join(
|
self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok)
|
||||||
room=self.room_id,
|
|
||||||
user=self.user,
|
|
||||||
tok=self.user_tok,
|
|
||||||
)
|
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
# We don't care whether the leave request didn't return a 200 (e.g.
|
# We don't care whether the leave request didn't return a 200 (e.g.
|
||||||
# if the user isn't already in the room), because we only want to
|
# if the user isn't already in the room), because we only want to
|
||||||
|
|
|
@ -37,10 +37,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
for i in range(0, 6):
|
for i in range(0, 6):
|
||||||
params = {
|
params = {
|
||||||
"type": "m.login.password",
|
"type": "m.login.password",
|
||||||
"identifier": {
|
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
|
||||||
"type": "m.id.user",
|
|
||||||
"user": "kermit" + str(i),
|
|
||||||
},
|
|
||||||
"password": "monkey",
|
"password": "monkey",
|
||||||
}
|
}
|
||||||
request_data = json.dumps(params)
|
request_data = json.dumps(params)
|
||||||
|
@ -57,14 +54,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
# than 1min.
|
# than 1min.
|
||||||
self.assertTrue(retry_after_ms < 6000)
|
self.assertTrue(retry_after_ms < 6000)
|
||||||
|
|
||||||
self.reactor.advance(retry_after_ms / 1000.)
|
self.reactor.advance(retry_after_ms / 1000.0)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"type": "m.login.password",
|
"type": "m.login.password",
|
||||||
"identifier": {
|
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
|
||||||
"type": "m.id.user",
|
|
||||||
"user": "kermit" + str(i),
|
|
||||||
},
|
|
||||||
"password": "monkey",
|
"password": "monkey",
|
||||||
}
|
}
|
||||||
request_data = json.dumps(params)
|
request_data = json.dumps(params)
|
||||||
|
@ -82,10 +76,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
for i in range(0, 6):
|
for i in range(0, 6):
|
||||||
params = {
|
params = {
|
||||||
"type": "m.login.password",
|
"type": "m.login.password",
|
||||||
"identifier": {
|
"identifier": {"type": "m.id.user", "user": "kermit"},
|
||||||
"type": "m.id.user",
|
|
||||||
"user": "kermit",
|
|
||||||
},
|
|
||||||
"password": "monkey",
|
"password": "monkey",
|
||||||
}
|
}
|
||||||
request_data = json.dumps(params)
|
request_data = json.dumps(params)
|
||||||
|
@ -102,14 +93,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
# than 1min.
|
# than 1min.
|
||||||
self.assertTrue(retry_after_ms < 6000)
|
self.assertTrue(retry_after_ms < 6000)
|
||||||
|
|
||||||
self.reactor.advance(retry_after_ms / 1000.)
|
self.reactor.advance(retry_after_ms / 1000.0)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"type": "m.login.password",
|
"type": "m.login.password",
|
||||||
"identifier": {
|
"identifier": {"type": "m.id.user", "user": "kermit"},
|
||||||
"type": "m.id.user",
|
|
||||||
"user": "kermit",
|
|
||||||
},
|
|
||||||
"password": "monkey",
|
"password": "monkey",
|
||||||
}
|
}
|
||||||
request_data = json.dumps(params)
|
request_data = json.dumps(params)
|
||||||
|
@ -127,10 +115,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
for i in range(0, 6):
|
for i in range(0, 6):
|
||||||
params = {
|
params = {
|
||||||
"type": "m.login.password",
|
"type": "m.login.password",
|
||||||
"identifier": {
|
"identifier": {"type": "m.id.user", "user": "kermit"},
|
||||||
"type": "m.id.user",
|
|
||||||
"user": "kermit",
|
|
||||||
},
|
|
||||||
"password": "notamonkey",
|
"password": "notamonkey",
|
||||||
}
|
}
|
||||||
request_data = json.dumps(params)
|
request_data = json.dumps(params)
|
||||||
|
@ -147,14 +132,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
# than 1min.
|
# than 1min.
|
||||||
self.assertTrue(retry_after_ms < 6000)
|
self.assertTrue(retry_after_ms < 6000)
|
||||||
|
|
||||||
self.reactor.advance(retry_after_ms / 1000.)
|
self.reactor.advance(retry_after_ms / 1000.0)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"type": "m.login.password",
|
"type": "m.login.password",
|
||||||
"identifier": {
|
"identifier": {"type": "m.id.user", "user": "kermit"},
|
||||||
"type": "m.id.user",
|
|
||||||
"user": "kermit",
|
|
||||||
},
|
|
||||||
"password": "notamonkey",
|
"password": "notamonkey",
|
||||||
}
|
}
|
||||||
request_data = json.dumps(params)
|
request_data = json.dumps(params)
|
||||||
|
|
|
@ -199,37 +199,24 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
|
||||||
def test_in_shared_room(self):
|
def test_in_shared_room(self):
|
||||||
self.ensure_requester_left_room()
|
self.ensure_requester_left_room()
|
||||||
|
|
||||||
self.helper.join(
|
self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok)
|
||||||
room=self.room_id,
|
|
||||||
user=self.requester,
|
|
||||||
tok=self.requester_tok,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.try_fetch_profile(200, self.requester_tok)
|
self.try_fetch_profile(200, self.requester_tok)
|
||||||
|
|
||||||
def try_fetch_profile(self, expected_code, access_token=None):
|
def try_fetch_profile(self, expected_code, access_token=None):
|
||||||
|
self.request_profile(expected_code, access_token=access_token)
|
||||||
|
|
||||||
self.request_profile(
|
self.request_profile(
|
||||||
expected_code,
|
expected_code, url_suffix="/displayname", access_token=access_token
|
||||||
access_token=access_token
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.request_profile(
|
self.request_profile(
|
||||||
expected_code,
|
expected_code, url_suffix="/avatar_url", access_token=access_token
|
||||||
url_suffix="/displayname",
|
|
||||||
access_token=access_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.request_profile(
|
|
||||||
expected_code,
|
|
||||||
url_suffix="/avatar_url",
|
|
||||||
access_token=access_token,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def request_profile(self, expected_code, url_suffix="", access_token=None):
|
def request_profile(self, expected_code, url_suffix="", access_token=None):
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"GET",
|
"GET", self.profile_url + url_suffix, access_token=access_token
|
||||||
self.profile_url + url_suffix,
|
|
||||||
access_token=access_token,
|
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(channel.code, expected_code, channel.result)
|
self.assertEqual(channel.code, expected_code, channel.result)
|
||||||
|
@ -237,9 +224,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
|
||||||
def ensure_requester_left_room(self):
|
def ensure_requester_left_room(self):
|
||||||
try:
|
try:
|
||||||
self.helper.leave(
|
self.helper.leave(
|
||||||
room=self.room_id,
|
room=self.room_id, user=self.requester, tok=self.requester_tok
|
||||||
user=self.requester,
|
|
||||||
tok=self.requester_tok,
|
|
||||||
)
|
)
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
# We don't care whether the leave request didn't return a 200 (e.g.
|
# We don't care whether the leave request didn't return a 200 (e.g.
|
||||||
|
|
|
@ -41,11 +41,10 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
as_token = "i_am_an_app_service"
|
as_token = "i_am_an_app_service"
|
||||||
|
|
||||||
appservice = ApplicationService(
|
appservice = ApplicationService(
|
||||||
as_token, self.hs.config.server_name,
|
as_token,
|
||||||
|
self.hs.config.server_name,
|
||||||
id="1234",
|
id="1234",
|
||||||
namespaces={
|
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
|
||||||
"users": [{"regex": r"@as_user.*", "exclusive": True}],
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hs.get_datastore().services_cache.append(appservice)
|
self.hs.get_datastore().services_cache.append(appservice)
|
||||||
|
@ -57,10 +56,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
det_data = {
|
det_data = {"user_id": user_id, "home_server": self.hs.hostname}
|
||||||
"user_id": user_id,
|
|
||||||
"home_server": self.hs.hostname,
|
|
||||||
}
|
|
||||||
self.assertDictContainsSubset(det_data, channel.json_body)
|
self.assertDictContainsSubset(det_data, channel.json_body)
|
||||||
|
|
||||||
def test_POST_appservice_registration_invalid(self):
|
def test_POST_appservice_registration_invalid(self):
|
||||||
|
@ -128,10 +124,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
|
||||||
det_data = {
|
det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
|
||||||
"home_server": self.hs.hostname,
|
|
||||||
"device_id": "guest_device",
|
|
||||||
}
|
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
self.assertDictContainsSubset(det_data, channel.json_body)
|
self.assertDictContainsSubset(det_data, channel.json_body)
|
||||||
|
|
||||||
|
@ -159,7 +152,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
else:
|
else:
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
self.reactor.advance(retry_after_ms / 1000.)
|
self.reactor.advance(retry_after_ms / 1000.0)
|
||||||
|
|
||||||
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
@ -187,7 +180,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
else:
|
else:
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
self.reactor.advance(retry_after_ms / 1000.)
|
self.reactor.advance(retry_after_ms / 1000.0)
|
||||||
|
|
||||||
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
@ -221,23 +214,19 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# The specific endpoint doesn't matter, all we need is an authenticated
|
# The specific endpoint doesn't matter, all we need is an authenticated
|
||||||
# endpoint.
|
# endpoint.
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(b"GET", "/sync", access_token=tok)
|
||||||
b"GET", "/sync", access_token=tok,
|
|
||||||
)
|
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
|
self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
|
||||||
|
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(b"GET", "/sync", access_token=tok)
|
||||||
b"GET", "/sync", access_token=tok,
|
|
||||||
)
|
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result,
|
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_manual_renewal(self):
|
def test_manual_renewal(self):
|
||||||
|
@ -253,21 +242,17 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
|
||||||
admin_tok = self.login("admin", "adminpassword")
|
admin_tok = self.login("admin", "adminpassword")
|
||||||
|
|
||||||
url = "/_matrix/client/unstable/admin/account_validity/validity"
|
url = "/_matrix/client/unstable/admin/account_validity/validity"
|
||||||
params = {
|
params = {"user_id": user_id}
|
||||||
"user_id": user_id,
|
|
||||||
}
|
|
||||||
request_data = json.dumps(params)
|
request_data = json.dumps(params)
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
b"POST", url, request_data, access_token=admin_tok,
|
b"POST", url, request_data, access_token=admin_tok
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
# The specific endpoint doesn't matter, all we need is an authenticated
|
# The specific endpoint doesn't matter, all we need is an authenticated
|
||||||
# endpoint.
|
# endpoint.
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(b"GET", "/sync", access_token=tok)
|
||||||
b"GET", "/sync", access_token=tok,
|
|
||||||
)
|
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
|
@ -286,20 +271,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
|
||||||
}
|
}
|
||||||
request_data = json.dumps(params)
|
request_data = json.dumps(params)
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
b"POST", url, request_data, access_token=admin_tok,
|
b"POST", url, request_data, access_token=admin_tok
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
# The specific endpoint doesn't matter, all we need is an authenticated
|
# The specific endpoint doesn't matter, all we need is an authenticated
|
||||||
# endpoint.
|
# endpoint.
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(b"GET", "/sync", access_token=tok)
|
||||||
b"GET", "/sync", access_token=tok,
|
|
||||||
)
|
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result,
|
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -358,10 +341,15 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||||
# We need to manually add an email address otherwise the handler will do
|
# We need to manually add an email address otherwise the handler will do
|
||||||
# nothing.
|
# nothing.
|
||||||
now = self.hs.clock.time_msec()
|
now = self.hs.clock.time_msec()
|
||||||
self.get_success(self.store.user_add_threepid(
|
self.get_success(
|
||||||
user_id=user_id, medium="email", address="kermit@example.com",
|
self.store.user_add_threepid(
|
||||||
validated_at=now, added_at=now,
|
user_id=user_id,
|
||||||
))
|
medium="email",
|
||||||
|
address="kermit@example.com",
|
||||||
|
validated_at=now,
|
||||||
|
added_at=now,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Move 6 days forward. This should trigger a renewal email to be sent.
|
# Move 6 days forward. This should trigger a renewal email to be sent.
|
||||||
self.reactor.advance(datetime.timedelta(days=6).total_seconds())
|
self.reactor.advance(datetime.timedelta(days=6).total_seconds())
|
||||||
|
@ -379,9 +367,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||||
# our access token should be denied from now, otherwise they should
|
# our access token should be denied from now, otherwise they should
|
||||||
# succeed.
|
# succeed.
|
||||||
self.reactor.advance(datetime.timedelta(days=3).total_seconds())
|
self.reactor.advance(datetime.timedelta(days=3).total_seconds())
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(b"GET", "/sync", access_token=tok)
|
||||||
b"GET", "/sync", access_token=tok,
|
|
||||||
)
|
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
|
|
||||||
|
@ -393,13 +379,19 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||||
# We need to manually add an email address otherwise the handler will do
|
# We need to manually add an email address otherwise the handler will do
|
||||||
# nothing.
|
# nothing.
|
||||||
now = self.hs.clock.time_msec()
|
now = self.hs.clock.time_msec()
|
||||||
self.get_success(self.store.user_add_threepid(
|
self.get_success(
|
||||||
user_id=user_id, medium="email", address="kermit@example.com",
|
self.store.user_add_threepid(
|
||||||
validated_at=now, added_at=now,
|
user_id=user_id,
|
||||||
))
|
medium="email",
|
||||||
|
address="kermit@example.com",
|
||||||
|
validated_at=now,
|
||||||
|
added_at=now,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
b"POST", "/_matrix/client/unstable/account_validity/send_mail",
|
b"POST",
|
||||||
|
"/_matrix/client/unstable/account_validity/send_mail",
|
||||||
access_token=tok,
|
access_token=tok,
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
|
|
@ -26,20 +26,14 @@ class GetFileNameFromHeadersTests(unittest.TestCase):
|
||||||
b'inline; filename="aze%20rty"': u"aze%20rty",
|
b'inline; filename="aze%20rty"': u"aze%20rty",
|
||||||
b'inline; filename="aze\"rty"': u'aze"rty',
|
b'inline; filename="aze\"rty"': u'aze"rty',
|
||||||
b'inline; filename="azer;ty"': u"azer;ty",
|
b'inline; filename="azer;ty"': u"azer;ty",
|
||||||
|
|
||||||
b"inline; filename*=utf-8''foo%C2%A3bar": u"foo£bar",
|
b"inline; filename*=utf-8''foo%C2%A3bar": u"foo£bar",
|
||||||
}
|
}
|
||||||
|
|
||||||
def tests(self):
|
def tests(self):
|
||||||
for hdr, expected in self.TEST_CASES.items():
|
for hdr, expected in self.TEST_CASES.items():
|
||||||
res = get_filename_from_headers(
|
res = get_filename_from_headers({b'Content-Disposition': [hdr]})
|
||||||
{
|
|
||||||
b'Content-Disposition': [hdr],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res, expected,
|
res,
|
||||||
"expected output for %s to be %s but was %s" % (
|
expected,
|
||||||
hdr, expected, res,
|
"expected output for %s to be %s but was %s" % (hdr, expected, res),
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -31,27 +31,24 @@ class WellKnownTests(unittest.HomeserverTestCase):
|
||||||
self.hs.config.default_identity_server = "https://testis"
|
self.hs.config.default_identity_server = "https://testis"
|
||||||
|
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"GET",
|
"GET", "/.well-known/matrix/client", shorthand=False
|
||||||
"/.well-known/matrix/client",
|
|
||||||
shorthand=False,
|
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
|
||||||
self.assertEqual(request.code, 200)
|
self.assertEqual(request.code, 200)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
channel.json_body, {
|
channel.json_body,
|
||||||
|
{
|
||||||
"m.homeserver": {"base_url": "https://tesths"},
|
"m.homeserver": {"base_url": "https://tesths"},
|
||||||
"m.identity_server": {"base_url": "https://testis"},
|
"m.identity_server": {"base_url": "https://testis"},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_well_known_no_public_baseurl(self):
|
def test_well_known_no_public_baseurl(self):
|
||||||
self.hs.config.public_baseurl = None
|
self.hs.config.public_baseurl = None
|
||||||
|
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"GET",
|
"GET", "/.well-known/matrix/client", shorthand=False
|
||||||
"/.well-known/matrix/client",
|
|
||||||
shorthand=False,
|
|
||||||
)
|
)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
|
||||||
|
|
|
@ -182,7 +182,8 @@ def make_request(
|
||||||
|
|
||||||
if federation_auth_origin is not None:
|
if federation_auth_origin is not None:
|
||||||
req.requestHeaders.addRawHeader(
|
req.requestHeaders.addRawHeader(
|
||||||
b"Authorization", b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
|
b"Authorization",
|
||||||
|
b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
|
||||||
)
|
)
|
||||||
|
|
||||||
if content:
|
if content:
|
||||||
|
|
|
@ -27,7 +27,6 @@ from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
hs_config = self.default_config("test")
|
hs_config = self.default_config("test")
|
||||||
hs_config.server_notices_mxid = "@server:test"
|
hs_config.server_notices_mxid = "@server:test"
|
||||||
|
|
|
@ -50,6 +50,7 @@ class FakeEvent(object):
|
||||||
refer to events. The event_id has node_id as localpart and example.com
|
refer to events. The event_id has node_id as localpart and example.com
|
||||||
as domain.
|
as domain.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, id, sender, type, state_key, content):
|
def __init__(self, id, sender, type, state_key, content):
|
||||||
self.node_id = id
|
self.node_id = id
|
||||||
self.event_id = EventID(id, "example.com").to_string()
|
self.event_id = EventID(id, "example.com").to_string()
|
||||||
|
@ -142,24 +143,14 @@ INITIAL_EVENTS = [
|
||||||
content=MEMBERSHIP_CONTENT_JOIN,
|
content=MEMBERSHIP_CONTENT_JOIN,
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="START",
|
id="START", sender=ZARA, type=EventTypes.Message, state_key=None, content={}
|
||||||
sender=ZARA,
|
|
||||||
type=EventTypes.Message,
|
|
||||||
state_key=None,
|
|
||||||
content={},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="END",
|
id="END", sender=ZARA, type=EventTypes.Message, state_key=None, content={}
|
||||||
sender=ZARA,
|
|
||||||
type=EventTypes.Message,
|
|
||||||
state_key=None,
|
|
||||||
content={},
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
INITIAL_EDGES = [
|
INITIAL_EDGES = ["START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"]
|
||||||
"START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class StateTestCase(unittest.TestCase):
|
class StateTestCase(unittest.TestCase):
|
||||||
|
@ -170,12 +161,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
sender=ALICE,
|
sender=ALICE,
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
state_key="",
|
state_key="",
|
||||||
content={
|
content={"users": {ALICE: 100, BOB: 50}},
|
||||||
"users": {
|
|
||||||
ALICE: 100,
|
|
||||||
BOB: 50,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="MA",
|
id="MA",
|
||||||
|
@ -196,19 +182,11 @@ class StateTestCase(unittest.TestCase):
|
||||||
sender=BOB,
|
sender=BOB,
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
state_key='',
|
state_key='',
|
||||||
content={
|
content={"users": {ALICE: 100, BOB: 50}},
|
||||||
"users": {
|
|
||||||
ALICE: 100,
|
|
||||||
BOB: 50,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
edges = [
|
edges = [["END", "MB", "MA", "PA", "START"], ["END", "PB", "PA"]]
|
||||||
["END", "MB", "MA", "PA", "START"],
|
|
||||||
["END", "PB", "PA"],
|
|
||||||
]
|
|
||||||
|
|
||||||
expected_state_ids = ["PA", "MA", "MB"]
|
expected_state_ids = ["PA", "MA", "MB"]
|
||||||
|
|
||||||
|
@ -232,10 +210,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
edges = [
|
edges = [["END", "JR", "START"], ["END", "ME", "START"]]
|
||||||
["END", "JR", "START"],
|
|
||||||
["END", "ME", "START"],
|
|
||||||
]
|
|
||||||
|
|
||||||
expected_state_ids = ["JR"]
|
expected_state_ids = ["JR"]
|
||||||
|
|
||||||
|
@ -248,45 +223,25 @@ class StateTestCase(unittest.TestCase):
|
||||||
sender=ALICE,
|
sender=ALICE,
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
state_key="",
|
state_key="",
|
||||||
content={
|
content={"users": {ALICE: 100, BOB: 50}},
|
||||||
"users": {
|
|
||||||
ALICE: 100,
|
|
||||||
BOB: 50,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="PB",
|
id="PB",
|
||||||
sender=BOB,
|
sender=BOB,
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
state_key='',
|
state_key='',
|
||||||
content={
|
content={"users": {ALICE: 100, BOB: 50, CHARLIE: 50}},
|
||||||
"users": {
|
|
||||||
ALICE: 100,
|
|
||||||
BOB: 50,
|
|
||||||
CHARLIE: 50,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="PC",
|
id="PC",
|
||||||
sender=CHARLIE,
|
sender=CHARLIE,
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
state_key='',
|
state_key='',
|
||||||
content={
|
content={"users": {ALICE: 100, BOB: 50, CHARLIE: 0}},
|
||||||
"users": {
|
|
||||||
ALICE: 100,
|
|
||||||
BOB: 50,
|
|
||||||
CHARLIE: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
edges = [
|
edges = [["END", "PC", "PB", "PA", "START"], ["END", "PA"]]
|
||||||
["END", "PC", "PB", "PA", "START"],
|
|
||||||
["END", "PA"],
|
|
||||||
]
|
|
||||||
|
|
||||||
expected_state_ids = ["PC"]
|
expected_state_ids = ["PC"]
|
||||||
|
|
||||||
|
@ -295,68 +250,38 @@ class StateTestCase(unittest.TestCase):
|
||||||
def test_topic_basic(self):
|
def test_topic_basic(self):
|
||||||
events = [
|
events = [
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="T1",
|
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
|
||||||
sender=ALICE,
|
|
||||||
type=EventTypes.Topic,
|
|
||||||
state_key="",
|
|
||||||
content={},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="PA1",
|
id="PA1",
|
||||||
sender=ALICE,
|
sender=ALICE,
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
state_key='',
|
state_key='',
|
||||||
content={
|
content={"users": {ALICE: 100, BOB: 50}},
|
||||||
"users": {
|
|
||||||
ALICE: 100,
|
|
||||||
BOB: 50,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="T2",
|
id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
|
||||||
sender=ALICE,
|
|
||||||
type=EventTypes.Topic,
|
|
||||||
state_key="",
|
|
||||||
content={},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="PA2",
|
id="PA2",
|
||||||
sender=ALICE,
|
sender=ALICE,
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
state_key='',
|
state_key='',
|
||||||
content={
|
content={"users": {ALICE: 100, BOB: 0}},
|
||||||
"users": {
|
|
||||||
ALICE: 100,
|
|
||||||
BOB: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="PB",
|
id="PB",
|
||||||
sender=BOB,
|
sender=BOB,
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
state_key='',
|
state_key='',
|
||||||
content={
|
content={"users": {ALICE: 100, BOB: 50}},
|
||||||
"users": {
|
|
||||||
ALICE: 100,
|
|
||||||
BOB: 50,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="T3",
|
id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={}
|
||||||
sender=BOB,
|
|
||||||
type=EventTypes.Topic,
|
|
||||||
state_key="",
|
|
||||||
content={},
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
edges = [
|
edges = [["END", "PA2", "T2", "PA1", "T1", "START"], ["END", "T3", "PB", "PA1"]]
|
||||||
["END", "PA2", "T2", "PA1", "T1", "START"],
|
|
||||||
["END", "T3", "PB", "PA1"],
|
|
||||||
]
|
|
||||||
|
|
||||||
expected_state_ids = ["PA2", "T2"]
|
expected_state_ids = ["PA2", "T2"]
|
||||||
|
|
||||||
|
@ -365,30 +290,17 @@ class StateTestCase(unittest.TestCase):
|
||||||
def test_topic_reset(self):
|
def test_topic_reset(self):
|
||||||
events = [
|
events = [
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="T1",
|
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
|
||||||
sender=ALICE,
|
|
||||||
type=EventTypes.Topic,
|
|
||||||
state_key="",
|
|
||||||
content={},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="PA",
|
id="PA",
|
||||||
sender=ALICE,
|
sender=ALICE,
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
state_key='',
|
state_key='',
|
||||||
content={
|
content={"users": {ALICE: 100, BOB: 50}},
|
||||||
"users": {
|
|
||||||
ALICE: 100,
|
|
||||||
BOB: 50,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="T2",
|
id="T2", sender=BOB, type=EventTypes.Topic, state_key="", content={}
|
||||||
sender=BOB,
|
|
||||||
type=EventTypes.Topic,
|
|
||||||
state_key="",
|
|
||||||
content={},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="MB",
|
id="MB",
|
||||||
|
@ -399,10 +311,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
edges = [
|
edges = [["END", "MB", "T2", "PA", "T1", "START"], ["END", "T1"]]
|
||||||
["END", "MB", "T2", "PA", "T1", "START"],
|
|
||||||
["END", "T1"],
|
|
||||||
]
|
|
||||||
|
|
||||||
expected_state_ids = ["T1", "MB", "PA"]
|
expected_state_ids = ["T1", "MB", "PA"]
|
||||||
|
|
||||||
|
@ -411,61 +320,34 @@ class StateTestCase(unittest.TestCase):
|
||||||
def test_topic(self):
|
def test_topic(self):
|
||||||
events = [
|
events = [
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="T1",
|
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
|
||||||
sender=ALICE,
|
|
||||||
type=EventTypes.Topic,
|
|
||||||
state_key="",
|
|
||||||
content={},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="PA1",
|
id="PA1",
|
||||||
sender=ALICE,
|
sender=ALICE,
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
state_key='',
|
state_key='',
|
||||||
content={
|
content={"users": {ALICE: 100, BOB: 50}},
|
||||||
"users": {
|
|
||||||
ALICE: 100,
|
|
||||||
BOB: 50,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="T2",
|
id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
|
||||||
sender=ALICE,
|
|
||||||
type=EventTypes.Topic,
|
|
||||||
state_key="",
|
|
||||||
content={},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="PA2",
|
id="PA2",
|
||||||
sender=ALICE,
|
sender=ALICE,
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
state_key='',
|
state_key='',
|
||||||
content={
|
content={"users": {ALICE: 100, BOB: 0}},
|
||||||
"users": {
|
|
||||||
ALICE: 100,
|
|
||||||
BOB: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="PB",
|
id="PB",
|
||||||
sender=BOB,
|
sender=BOB,
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
state_key='',
|
state_key='',
|
||||||
content={
|
content={"users": {ALICE: 100, BOB: 50}},
|
||||||
"users": {
|
|
||||||
ALICE: 100,
|
|
||||||
BOB: 50,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="T3",
|
id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={}
|
||||||
sender=BOB,
|
|
||||||
type=EventTypes.Topic,
|
|
||||||
state_key="",
|
|
||||||
content={},
|
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="MZ1",
|
id="MZ1",
|
||||||
|
@ -475,11 +357,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
content={},
|
content={},
|
||||||
),
|
),
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="T4",
|
id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
|
||||||
sender=ALICE,
|
|
||||||
type=EventTypes.Topic,
|
|
||||||
state_key="",
|
|
||||||
content={},
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -587,13 +465,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
class LexicographicalTestCase(unittest.TestCase):
|
class LexicographicalTestCase(unittest.TestCase):
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
graph = {
|
graph = {"l": {"o"}, "m": {"n", "o"}, "n": {"o"}, "o": set(), "p": {"o"}}
|
||||||
"l": {"o"},
|
|
||||||
"m": {"n", "o"},
|
|
||||||
"n": {"o"},
|
|
||||||
"o": set(),
|
|
||||||
"p": {"o"},
|
|
||||||
}
|
|
||||||
|
|
||||||
res = list(lexicographical_topological_sort(graph, key=lambda x: x))
|
res = list(lexicographical_topological_sort(graph, key=lambda x: x))
|
||||||
|
|
||||||
|
@ -680,7 +552,13 @@ class SimpleParamStateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.expected_combined_state = {
|
self.expected_combined_state = {
|
||||||
(e.type, e.state_key): e.event_id
|
(e.type, e.state_key): e.event_id
|
||||||
for e in [create_event, alice_member, join_rules, bob_member, charlie_member]
|
for e in [
|
||||||
|
create_event,
|
||||||
|
alice_member,
|
||||||
|
join_rules,
|
||||||
|
bob_member,
|
||||||
|
charlie_member,
|
||||||
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_event_map_none(self):
|
def test_event_map_none(self):
|
||||||
|
@ -720,11 +598,7 @@ class TestStateResolutionStore(object):
|
||||||
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
|
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return {
|
return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
|
||||||
eid: self.event_map[eid]
|
|
||||||
for eid in event_ids
|
|
||||||
if eid in self.event_map
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_auth_chain(self, event_ids):
|
def get_auth_chain(self, event_ids):
|
||||||
"""Gets the full auth chain for a set of events (including rejected
|
"""Gets the full auth chain for a set of events (including rejected
|
||||||
|
|
|
@ -9,9 +9,7 @@ from tests.utils import setup_test_homeserver
|
||||||
class BackgroundUpdateTestCase(unittest.TestCase):
|
class BackgroundUpdateTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
hs = yield setup_test_homeserver(
|
hs = yield setup_test_homeserver(self.addCleanup)
|
||||||
self.addCleanup
|
|
||||||
)
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
|
|
@ -56,10 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
fake_engine = Mock(wraps=engine)
|
fake_engine = Mock(wraps=engine)
|
||||||
fake_engine.can_native_upsert = False
|
fake_engine.can_native_upsert = False
|
||||||
hs = TestHomeServer(
|
hs = TestHomeServer(
|
||||||
"test",
|
"test", db_pool=self.db_pool, config=config, database_engine=fake_engine
|
||||||
db_pool=self.db_pool,
|
|
||||||
config=config,
|
|
||||||
database_engine=fake_engine,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.datastore = SQLBaseStore(None, hs)
|
self.datastore = SQLBaseStore(None, hs)
|
||||||
|
|
|
@ -20,7 +20,6 @@ import tests.utils
|
||||||
|
|
||||||
|
|
||||||
class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
|
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
|
||||||
|
|
|
@ -56,8 +56,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.register(user_id=user1, token="123", password_hash=None)
|
self.store.register(user_id=user1, token="123", password_hash=None)
|
||||||
self.store.register(user_id=user2, token="456", password_hash=None)
|
self.store.register(user_id=user2, token="456", password_hash=None)
|
||||||
self.store.register(
|
self.store.register(
|
||||||
user_id=user3, token="789",
|
user_id=user3, token="789", password_hash=None, user_type=UserTypes.SUPPORT
|
||||||
password_hash=None, user_type=UserTypes.SUPPORT
|
|
||||||
)
|
)
|
||||||
self.pump()
|
self.pump()
|
||||||
|
|
||||||
|
@ -173,9 +172,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
def test_populate_monthly_users_should_update(self):
|
def test_populate_monthly_users_should_update(self):
|
||||||
self.store.upsert_monthly_active_user = Mock()
|
self.store.upsert_monthly_active_user = Mock()
|
||||||
|
|
||||||
self.store.is_trial_user = Mock(
|
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
|
||||||
return_value=defer.succeed(False)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.store.user_last_seen_monthly_active = Mock(
|
self.store.user_last_seen_monthly_active = Mock(
|
||||||
return_value=defer.succeed(None)
|
return_value=defer.succeed(None)
|
||||||
|
@ -187,13 +184,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
def test_populate_monthly_users_should_not_update(self):
|
def test_populate_monthly_users_should_not_update(self):
|
||||||
self.store.upsert_monthly_active_user = Mock()
|
self.store.upsert_monthly_active_user = Mock()
|
||||||
|
|
||||||
self.store.is_trial_user = Mock(
|
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
|
||||||
return_value=defer.succeed(False)
|
|
||||||
)
|
|
||||||
self.store.user_last_seen_monthly_active = Mock(
|
self.store.user_last_seen_monthly_active = Mock(
|
||||||
return_value=defer.succeed(
|
return_value=defer.succeed(self.hs.get_clock().time_msec())
|
||||||
self.hs.get_clock().time_msec()
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
self.store.populate_monthly_active_users('user_id')
|
self.store.populate_monthly_active_users('user_id')
|
||||||
self.pump()
|
self.pump()
|
||||||
|
@ -243,7 +236,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||||
user_id=support_user_id,
|
user_id=support_user_id,
|
||||||
token="123",
|
token="123",
|
||||||
password_hash=None,
|
password_hash=None,
|
||||||
user_type=UserTypes.SUPPORT
|
user_type=UserTypes.SUPPORT,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.store.upsert_monthly_active_user(support_user_id)
|
self.store.upsert_monthly_active_user(support_user_id)
|
||||||
|
|
|
@ -60,7 +60,7 @@ class RedactionTestCase(unittest.TestCase):
|
||||||
"state_key": user.to_string(),
|
"state_key": user.to_string(),
|
||||||
"room_id": room.to_string(),
|
"room_id": room.to_string(),
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||||
|
@ -83,7 +83,7 @@ class RedactionTestCase(unittest.TestCase):
|
||||||
"state_key": user.to_string(),
|
"state_key": user.to_string(),
|
||||||
"room_id": room.to_string(),
|
"room_id": room.to_string(),
|
||||||
"content": {"body": body, "msgtype": u"message"},
|
"content": {"body": body, "msgtype": u"message"},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||||
|
@ -105,7 +105,7 @@ class RedactionTestCase(unittest.TestCase):
|
||||||
"room_id": room.to_string(),
|
"room_id": room.to_string(),
|
||||||
"content": {"reason": reason},
|
"content": {"reason": reason},
|
||||||
"redacts": event_id,
|
"redacts": event_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||||
|
|
|
@ -116,7 +116,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
||||||
user_id=SUPPORT_USER,
|
user_id=SUPPORT_USER,
|
||||||
token="456",
|
token="456",
|
||||||
password_hash=None,
|
password_hash=None,
|
||||||
user_type=UserTypes.SUPPORT
|
user_type=UserTypes.SUPPORT,
|
||||||
)
|
)
|
||||||
res = yield self.store.is_support_user(SUPPORT_USER)
|
res = yield self.store.is_support_user(SUPPORT_USER)
|
||||||
self.assertTrue(res)
|
self.assertTrue(res)
|
||||||
|
|
|
@ -58,7 +58,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
|
||||||
"state_key": user.to_string(),
|
"state_key": user.to_string(),
|
||||||
"room_id": room.to_string(),
|
"room_id": room.to_string(),
|
||||||
"content": {"membership": membership},
|
"content": {"membership": membership},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||||
|
|
|
@ -29,7 +29,6 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class StateStoreTestCase(tests.unittest.TestCase):
|
class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
|
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
|
||||||
|
@ -57,7 +56,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
"state_key": state_key,
|
"state_key": state_key,
|
||||||
"room_id": room.to_string(),
|
"room_id": room.to_string(),
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||||
|
@ -83,15 +82,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
|
self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
|
||||||
)
|
)
|
||||||
|
|
||||||
state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id])
|
state_group_map = yield self.store.get_state_groups_ids(
|
||||||
|
self.room, [e2.event_id]
|
||||||
|
)
|
||||||
self.assertEqual(len(state_group_map), 1)
|
self.assertEqual(len(state_group_map), 1)
|
||||||
state_map = list(state_group_map.values())[0]
|
state_map = list(state_group_map.values())[0]
|
||||||
self.assertDictEqual(
|
self.assertDictEqual(
|
||||||
state_map,
|
state_map,
|
||||||
{
|
{(EventTypes.Create, ''): e1.event_id, (EventTypes.Name, ''): e2.event_id},
|
||||||
(EventTypes.Create, ''): e1.event_id,
|
|
||||||
(EventTypes.Name, ''): e2.event_id,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -103,15 +101,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
|
self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
|
||||||
)
|
)
|
||||||
|
|
||||||
state_group_map = yield self.store.get_state_groups(
|
state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id])
|
||||||
self.room, [e2.event_id])
|
|
||||||
self.assertEqual(len(state_group_map), 1)
|
self.assertEqual(len(state_group_map), 1)
|
||||||
state_list = list(state_group_map.values())[0]
|
state_list = list(state_group_map.values())[0]
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
|
||||||
{ev.event_id for ev in state_list},
|
|
||||||
{e1.event_id, e2.event_id},
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_state_for_event(self):
|
def test_get_state_for_event(self):
|
||||||
|
@ -147,9 +141,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# check we get the full state as of the final event
|
# check we get the full state as of the final event
|
||||||
state = yield self.store.get_state_for_event(
|
state = yield self.store.get_state_for_event(e5.event_id)
|
||||||
e5.event_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsNotNone(e4)
|
self.assertIsNotNone(e4)
|
||||||
|
|
||||||
|
@ -194,7 +186,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: {self.u_alice.to_string()}},
|
types={EventTypes.Member: {self.u_alice.to_string()}},
|
||||||
include_others=True,
|
include_others=True,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertStateMapEqual(
|
self.assertStateMapEqual(
|
||||||
|
@ -208,9 +200,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
|
|
||||||
# check that we can grab everything except members
|
# check that we can grab everything except members
|
||||||
state = yield self.store.get_state_for_event(
|
state = yield self.store.get_state_for_event(
|
||||||
e5.event_id, state_filter=StateFilter(
|
e5.event_id,
|
||||||
types={EventTypes.Member: set()},
|
state_filter=StateFilter(
|
||||||
include_others=True,
|
types={EventTypes.Member: set()}, include_others=True
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -229,10 +221,10 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
# test _get_state_for_group_using_cache correctly filters out members
|
# test _get_state_for_group_using_cache correctly filters out members
|
||||||
# with types=[]
|
# with types=[]
|
||||||
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
|
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
|
||||||
self.store._state_group_cache, group,
|
self.store._state_group_cache,
|
||||||
|
group,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: set()},
|
types={EventTypes.Member: set()}, include_others=True
|
||||||
include_others=True,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -249,8 +241,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.store._state_group_members_cache,
|
self.store._state_group_members_cache,
|
||||||
group,
|
group,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: set()},
|
types={EventTypes.Member: set()}, include_others=True
|
||||||
include_others=True,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -263,8 +254,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.store._state_group_cache,
|
self.store._state_group_cache,
|
||||||
group,
|
group,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: None},
|
types={EventTypes.Member: None}, include_others=True
|
||||||
include_others=True,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -281,8 +271,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.store._state_group_members_cache,
|
self.store._state_group_members_cache,
|
||||||
group,
|
group,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: None},
|
types={EventTypes.Member: None}, include_others=True
|
||||||
include_others=True,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -302,8 +291,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.store._state_group_cache,
|
self.store._state_group_cache,
|
||||||
group,
|
group,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: {e5.state_key}},
|
types={EventTypes.Member: {e5.state_key}}, include_others=True
|
||||||
include_others=True,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -320,8 +308,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.store._state_group_members_cache,
|
self.store._state_group_members_cache,
|
||||||
group,
|
group,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: {e5.state_key}},
|
types={EventTypes.Member: {e5.state_key}}, include_others=True
|
||||||
include_others=True,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -334,8 +321,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.store._state_group_members_cache,
|
self.store._state_group_members_cache,
|
||||||
group,
|
group,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: {e5.state_key}},
|
types={EventTypes.Member: {e5.state_key}}, include_others=False
|
||||||
include_others=False,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -384,10 +370,10 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
# with types=[]
|
# with types=[]
|
||||||
room_id = self.room.to_string()
|
room_id = self.room.to_string()
|
||||||
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
|
(state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
|
||||||
self.store._state_group_cache, group,
|
self.store._state_group_cache,
|
||||||
|
group,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: set()},
|
types={EventTypes.Member: set()}, include_others=True
|
||||||
include_others=True,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -399,8 +385,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.store._state_group_members_cache,
|
self.store._state_group_members_cache,
|
||||||
group,
|
group,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: set()},
|
types={EventTypes.Member: set()}, include_others=True
|
||||||
include_others=True,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -413,8 +398,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.store._state_group_cache,
|
self.store._state_group_cache,
|
||||||
group,
|
group,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: None},
|
types={EventTypes.Member: None}, include_others=True
|
||||||
include_others=True,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -425,8 +409,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.store._state_group_members_cache,
|
self.store._state_group_members_cache,
|
||||||
group,
|
group,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: None},
|
types={EventTypes.Member: None}, include_others=True
|
||||||
include_others=True,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -445,8 +428,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.store._state_group_cache,
|
self.store._state_group_cache,
|
||||||
group,
|
group,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: {e5.state_key}},
|
types={EventTypes.Member: {e5.state_key}}, include_others=True
|
||||||
include_others=True,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -457,8 +439,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.store._state_group_members_cache,
|
self.store._state_group_members_cache,
|
||||||
group,
|
group,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: {e5.state_key}},
|
types={EventTypes.Member: {e5.state_key}}, include_others=True
|
||||||
include_others=True,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -471,8 +452,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.store._state_group_cache,
|
self.store._state_group_cache,
|
||||||
group,
|
group,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: {e5.state_key}},
|
types={EventTypes.Member: {e5.state_key}}, include_others=False
|
||||||
include_others=False,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -483,8 +463,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.store._state_group_members_cache,
|
self.store._state_group_members_cache,
|
||||||
group,
|
group,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: {e5.state_key}},
|
types={EventTypes.Member: {e5.state_key}}, include_others=False
|
||||||
include_others=False,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -36,9 +36,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
|
||||||
yield self.store.update_profile_in_user_dir(ALICE, "alice", None)
|
yield self.store.update_profile_in_user_dir(ALICE, "alice", None)
|
||||||
yield self.store.update_profile_in_user_dir(BOB, "bob", None)
|
yield self.store.update_profile_in_user_dir(BOB, "bob", None)
|
||||||
yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
|
yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
|
||||||
yield self.store.add_users_in_public_rooms(
|
yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
|
||||||
"!room:id", (ALICE, BOB)
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_search_user_dir(self):
|
def test_search_user_dir(self):
|
||||||
|
|
|
@ -37,7 +37,9 @@ class EventAuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# creator should be able to send state
|
# creator should be able to send state
|
||||||
event_auth.check(
|
event_auth.check(
|
||||||
RoomVersions.V1.identifier, _random_state_event(creator), auth_events,
|
RoomVersions.V1.identifier,
|
||||||
|
_random_state_event(creator),
|
||||||
|
auth_events,
|
||||||
do_sig_check=False,
|
do_sig_check=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -82,7 +84,9 @@ class EventAuthTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# king should be able to send state
|
# king should be able to send state
|
||||||
event_auth.check(
|
event_auth.check(
|
||||||
RoomVersions.V1.identifier, _random_state_event(king), auth_events,
|
RoomVersions.V1.identifier,
|
||||||
|
_random_state_event(king),
|
||||||
|
auth_events,
|
||||||
do_sig_check=False,
|
do_sig_check=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
from twisted.internet.defer import maybeDeferred, succeed
|
from twisted.internet.defer import maybeDeferred, succeed
|
||||||
|
|
|
@ -33,9 +33,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
|
|
||||||
self.hs = self.setup_test_homeserver(
|
self.hs = self.setup_test_homeserver(
|
||||||
"red",
|
"red", http_client=None, federation_client=Mock()
|
||||||
http_client=None,
|
|
||||||
federation_client=Mock(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
|
@ -210,9 +208,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
|
||||||
return access_token
|
return access_token
|
||||||
|
|
||||||
def do_sync_for_user(self, token):
|
def do_sync_for_user(self, token):
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request("GET", "/sync", access_token=token)
|
||||||
"GET", "/sync", access_token=token
|
|
||||||
)
|
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
|
||||||
if channel.code != 200:
|
if channel.code != 200:
|
||||||
|
|
|
@ -44,9 +44,7 @@ def get_sample_labels_value(sample):
|
||||||
class TestMauLimit(unittest.TestCase):
|
class TestMauLimit(unittest.TestCase):
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
gauge = InFlightGauge(
|
gauge = InFlightGauge(
|
||||||
"test1", "",
|
"test1", "", labels=["test_label"], sub_metrics=["foo", "bar"]
|
||||||
labels=["test_label"],
|
|
||||||
sub_metrics=["foo", "bar"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle1(metrics):
|
def handle1(metrics):
|
||||||
|
@ -59,37 +57,49 @@ class TestMauLimit(unittest.TestCase):
|
||||||
|
|
||||||
gauge.register(("key1",), handle1)
|
gauge.register(("key1",), handle1)
|
||||||
|
|
||||||
self.assert_dict({
|
self.assert_dict(
|
||||||
|
{
|
||||||
"test1_total": {("key1",): 1},
|
"test1_total": {("key1",): 1},
|
||||||
"test1_foo": {("key1",): 2},
|
"test1_foo": {("key1",): 2},
|
||||||
"test1_bar": {("key1",): 5},
|
"test1_bar": {("key1",): 5},
|
||||||
}, self.get_metrics_from_gauge(gauge))
|
},
|
||||||
|
self.get_metrics_from_gauge(gauge),
|
||||||
|
)
|
||||||
|
|
||||||
gauge.unregister(("key1",), handle1)
|
gauge.unregister(("key1",), handle1)
|
||||||
|
|
||||||
self.assert_dict({
|
self.assert_dict(
|
||||||
|
{
|
||||||
"test1_total": {("key1",): 0},
|
"test1_total": {("key1",): 0},
|
||||||
"test1_foo": {("key1",): 0},
|
"test1_foo": {("key1",): 0},
|
||||||
"test1_bar": {("key1",): 0},
|
"test1_bar": {("key1",): 0},
|
||||||
}, self.get_metrics_from_gauge(gauge))
|
},
|
||||||
|
self.get_metrics_from_gauge(gauge),
|
||||||
|
)
|
||||||
|
|
||||||
gauge.register(("key1",), handle1)
|
gauge.register(("key1",), handle1)
|
||||||
gauge.register(("key2",), handle2)
|
gauge.register(("key2",), handle2)
|
||||||
|
|
||||||
self.assert_dict({
|
self.assert_dict(
|
||||||
|
{
|
||||||
"test1_total": {("key1",): 1, ("key2",): 1},
|
"test1_total": {("key1",): 1, ("key2",): 1},
|
||||||
"test1_foo": {("key1",): 2, ("key2",): 3},
|
"test1_foo": {("key1",): 2, ("key2",): 3},
|
||||||
"test1_bar": {("key1",): 5, ("key2",): 7},
|
"test1_bar": {("key1",): 5, ("key2",): 7},
|
||||||
}, self.get_metrics_from_gauge(gauge))
|
},
|
||||||
|
self.get_metrics_from_gauge(gauge),
|
||||||
|
)
|
||||||
|
|
||||||
gauge.unregister(("key2",), handle2)
|
gauge.unregister(("key2",), handle2)
|
||||||
gauge.register(("key1",), handle2)
|
gauge.register(("key1",), handle2)
|
||||||
|
|
||||||
self.assert_dict({
|
self.assert_dict(
|
||||||
|
{
|
||||||
"test1_total": {("key1",): 2, ("key2",): 0},
|
"test1_total": {("key1",): 2, ("key2",): 0},
|
||||||
"test1_foo": {("key1",): 5, ("key2",): 0},
|
"test1_foo": {("key1",): 5, ("key2",): 0},
|
||||||
"test1_bar": {("key1",): 7, ("key2",): 0},
|
"test1_bar": {("key1",): 7, ("key2",): 0},
|
||||||
}, self.get_metrics_from_gauge(gauge))
|
},
|
||||||
|
self.get_metrics_from_gauge(gauge),
|
||||||
|
)
|
||||||
|
|
||||||
def get_metrics_from_gauge(self, gauge):
|
def get_metrics_from_gauge(self, gauge):
|
||||||
results = {}
|
results = {}
|
||||||
|
|
|
@ -69,10 +69,10 @@ class TermsTestCase(unittest.HomeserverTestCase):
|
||||||
"name": "My Cool Privacy Policy",
|
"name": "My Cool Privacy Policy",
|
||||||
"url": "https://example.org/_matrix/consent?v=1.0",
|
"url": "https://example.org/_matrix/consent?v=1.0",
|
||||||
},
|
},
|
||||||
"version": "1.0"
|
"version": "1.0",
|
||||||
},
|
}
|
||||||
},
|
}
|
||||||
},
|
}
|
||||||
}
|
}
|
||||||
self.assertIsInstance(channel.json_body["params"], dict)
|
self.assertIsInstance(channel.json_body["params"], dict)
|
||||||
self.assertDictContainsSubset(channel.json_body["params"], expected_params)
|
self.assertDictContainsSubset(channel.json_body["params"], expected_params)
|
||||||
|
|
|
@ -94,8 +94,7 @@ class MapUsernameTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def testSymbols(self):
|
def testSymbols(self):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
map_username_to_mxid_localpart("test=$?_1234"),
|
map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234"
|
||||||
"test=3d=24=3f_1234",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def testLeadingUnderscore(self):
|
def testLeadingUnderscore(self):
|
||||||
|
@ -105,6 +104,5 @@ class MapUsernameTestCase(unittest.TestCase):
|
||||||
# this should work with either a unicode or a bytes
|
# this should work with either a unicode or a bytes
|
||||||
self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast")
|
self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
map_username_to_mxid_localpart(u'têst'.encode('utf-8')),
|
map_username_to_mxid_localpart(u'têst'.encode('utf-8')), "t=c3=aast"
|
||||||
"t=c3=aast",
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -22,6 +22,7 @@ from synapse.util.logcontext import LoggingContextFilter
|
||||||
|
|
||||||
class ToTwistedHandler(logging.Handler):
|
class ToTwistedHandler(logging.Handler):
|
||||||
"""logging handler which sends the logs to the twisted log"""
|
"""logging handler which sends the logs to the twisted log"""
|
||||||
|
|
||||||
tx_log = twisted.logger.Logger()
|
tx_log = twisted.logger.Logger()
|
||||||
|
|
||||||
def emit(self, record):
|
def emit(self, record):
|
||||||
|
@ -41,7 +42,8 @@ def setup_logging():
|
||||||
root_logger = logging.getLogger()
|
root_logger = logging.getLogger()
|
||||||
|
|
||||||
log_format = (
|
log_format = (
|
||||||
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
|
"%(asctime)s - %(name)s - %(lineno)d - "
|
||||||
|
"%(levelname)s - %(request)s - %(message)s"
|
||||||
)
|
)
|
||||||
|
|
||||||
handler = ToTwistedHandler()
|
handler = ToTwistedHandler()
|
||||||
|
|
|
@ -132,7 +132,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||||
"state_key": "",
|
"state_key": "",
|
||||||
"room_id": TEST_ROOM_ID,
|
"room_id": TEST_ROOM_ID,
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||||
|
@ -153,7 +153,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||||
"state_key": user_id,
|
"state_key": user_id,
|
||||||
"room_id": TEST_ROOM_ID,
|
"room_id": TEST_ROOM_ID,
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||||
|
@ -174,7 +174,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||||
"sender": user_id,
|
"sender": user_id,
|
||||||
"room_id": TEST_ROOM_ID,
|
"room_id": TEST_ROOM_ID,
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||||
|
|
|
@ -84,9 +84,8 @@ class TestCase(unittest.TestCase):
|
||||||
# all future bets are off.
|
# all future bets are off.
|
||||||
if LoggingContext.current_context() is not LoggingContext.sentinel:
|
if LoggingContext.current_context() is not LoggingContext.sentinel:
|
||||||
self.fail(
|
self.fail(
|
||||||
"Test starting with non-sentinel logging context %s" % (
|
"Test starting with non-sentinel logging context %s"
|
||||||
LoggingContext.current_context(),
|
% (LoggingContext.current_context(),)
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
old_level = logging.getLogger().level
|
old_level = logging.getLogger().level
|
||||||
|
@ -300,7 +299,13 @@ class HomeserverTestCase(TestCase):
|
||||||
content = json.dumps(content).encode('utf8')
|
content = json.dumps(content).encode('utf8')
|
||||||
|
|
||||||
return make_request(
|
return make_request(
|
||||||
self.reactor, method, path, content, access_token, request, shorthand,
|
self.reactor,
|
||||||
|
method,
|
||||||
|
path,
|
||||||
|
content,
|
||||||
|
access_token,
|
||||||
|
request,
|
||||||
|
shorthand,
|
||||||
federation_auth_origin,
|
federation_auth_origin,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,7 @@ class TimeoutDeferredTest(TestCase):
|
||||||
self.clock.pump((1.0,))
|
self.clock.pump((1.0,))
|
||||||
|
|
||||||
self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
|
self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
|
||||||
self.failureResultOf(timing_out_d, defer.TimeoutError, )
|
self.failureResultOf(timing_out_d, defer.TimeoutError)
|
||||||
|
|
||||||
def test_times_out_when_canceller_throws(self):
|
def test_times_out_when_canceller_throws(self):
|
||||||
"""Test that we have successfully worked around
|
"""Test that we have successfully worked around
|
||||||
|
@ -61,7 +61,7 @@ class TimeoutDeferredTest(TestCase):
|
||||||
|
|
||||||
self.clock.pump((1.0,))
|
self.clock.pump((1.0,))
|
||||||
|
|
||||||
self.failureResultOf(timing_out_d, defer.TimeoutError, )
|
self.failureResultOf(timing_out_d, defer.TimeoutError)
|
||||||
|
|
||||||
def test_logcontext_is_preserved_on_cancellation(self):
|
def test_logcontext_is_preserved_on_cancellation(self):
|
||||||
blocking_was_cancelled = [False]
|
blocking_was_cancelled = [False]
|
||||||
|
@ -80,10 +80,10 @@ class TimeoutDeferredTest(TestCase):
|
||||||
# the errbacks should be run in the test logcontext
|
# the errbacks should be run in the test logcontext
|
||||||
def errback(res, deferred_name):
|
def errback(res, deferred_name):
|
||||||
self.assertIs(
|
self.assertIs(
|
||||||
LoggingContext.current_context(), context_one,
|
LoggingContext.current_context(),
|
||||||
"errback %s run in unexpected logcontext %s" % (
|
context_one,
|
||||||
deferred_name, LoggingContext.current_context(),
|
"errback %s run in unexpected logcontext %s"
|
||||||
)
|
% (deferred_name, LoggingContext.current_context()),
|
||||||
)
|
)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -97,8 +97,7 @@ class TimeoutDeferredTest(TestCase):
|
||||||
self.clock.pump((1.0,))
|
self.clock.pump((1.0,))
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
blocking_was_cancelled[0],
|
blocking_was_cancelled[0], "non-completing deferred was not cancelled"
|
||||||
"non-completing deferred was not cancelled",
|
|
||||||
)
|
)
|
||||||
self.failureResultOf(timing_out_d, defer.TimeoutError, )
|
self.failureResultOf(timing_out_d, defer.TimeoutError)
|
||||||
self.assertIs(LoggingContext.current_context(), context_one)
|
self.assertIs(LoggingContext.current_context(), context_one)
|
||||||
|
|
|
@ -68,7 +68,9 @@ def setupdb():
|
||||||
|
|
||||||
# connect to postgres to create the base database.
|
# connect to postgres to create the base database.
|
||||||
db_conn = db_engine.module.connect(
|
db_conn = db_engine.module.connect(
|
||||||
user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD,
|
user=POSTGRES_USER,
|
||||||
|
host=POSTGRES_HOST,
|
||||||
|
password=POSTGRES_PASSWORD,
|
||||||
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
|
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
|
||||||
)
|
)
|
||||||
db_conn.autocommit = True
|
db_conn.autocommit = True
|
||||||
|
@ -94,7 +96,9 @@ def setupdb():
|
||||||
|
|
||||||
def _cleanup():
|
def _cleanup():
|
||||||
db_conn = db_engine.module.connect(
|
db_conn = db_engine.module.connect(
|
||||||
user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD,
|
user=POSTGRES_USER,
|
||||||
|
host=POSTGRES_HOST,
|
||||||
|
password=POSTGRES_PASSWORD,
|
||||||
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
|
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
|
||||||
)
|
)
|
||||||
db_conn.autocommit = True
|
db_conn.autocommit = True
|
||||||
|
@ -114,7 +118,6 @@ def default_config(name):
|
||||||
"server_name": name,
|
"server_name": name,
|
||||||
"media_store_path": "media",
|
"media_store_path": "media",
|
||||||
"uploads_path": "uploads",
|
"uploads_path": "uploads",
|
||||||
|
|
||||||
# the test signing key is just an arbitrary ed25519 key to keep the config
|
# the test signing key is just an arbitrary ed25519 key to keep the config
|
||||||
# parser happy
|
# parser happy
|
||||||
"signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg",
|
"signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg",
|
||||||
|
|
Loading…
Reference in a new issue