forked from MirrorHub/synapse
Properly typecheck tests.api (#14983)
This commit is contained in:
parent
b2d97bac09
commit
6e6edea6c1
7 changed files with 141 additions and 111 deletions
1
changelog.d/14983.misc
Normal file
1
changelog.d/14983.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Improve type hints.
|
4
mypy.ini
4
mypy.ini
|
@ -32,7 +32,6 @@ exclude = (?x)
|
||||||
|synapse/storage/databases/main/cache.py
|
|synapse/storage/databases/main/cache.py
|
||||||
|synapse/storage/schema/
|
|synapse/storage/schema/
|
||||||
|
|
||||||
|tests/api/test_auth.py
|
|
||||||
|tests/appservice/test_scheduler.py
|
|tests/appservice/test_scheduler.py
|
||||||
|tests/federation/test_federation_catch_up.py
|
|tests/federation/test_federation_catch_up.py
|
||||||
|tests/federation/test_federation_sender.py
|
|tests/federation/test_federation_sender.py
|
||||||
|
@ -73,6 +72,9 @@ disallow_untyped_defs = False
|
||||||
[mypy-tests.*]
|
[mypy-tests.*]
|
||||||
disallow_untyped_defs = False
|
disallow_untyped_defs = False
|
||||||
|
|
||||||
|
[mypy-tests.api.*]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.app.*]
|
[mypy-tests.app.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
|
|
@ -252,9 +252,9 @@ class FilterCollection:
|
||||||
return self._room_timeline_filter.unread_thread_notifications
|
return self._room_timeline_filter.unread_thread_notifications
|
||||||
|
|
||||||
async def filter_presence(
|
async def filter_presence(
|
||||||
self, events: Iterable[UserPresenceState]
|
self, presence_states: Iterable[UserPresenceState]
|
||||||
) -> List[UserPresenceState]:
|
) -> List[UserPresenceState]:
|
||||||
return await self._presence_filter.filter(events)
|
return await self._presence_filter.filter(presence_states)
|
||||||
|
|
||||||
async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
|
async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
|
||||||
return await self._account_data.filter(events)
|
return await self._account_data.filter(events)
|
||||||
|
|
|
@ -31,7 +31,7 @@ from synapse.api.errors import (
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||||
from synapse.types import Requester
|
from synapse.types import Requester, UserID
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -41,10 +41,12 @@ from tests.utils import mock_getRawHeaders
|
||||||
|
|
||||||
|
|
||||||
class AuthTestCase(unittest.HomeserverTestCase):
|
class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = Mock()
|
self.store = Mock()
|
||||||
|
|
||||||
hs.datastores.main = self.store
|
# type-ignore: datastores is None until hs.setup() is called---but it'll
|
||||||
|
# have been called by the HomeserverTestCase machinery.
|
||||||
|
hs.datastores.main = self.store # type: ignore[union-attr]
|
||||||
hs.get_auth_handler().store = self.store
|
hs.get_auth_handler().store = self.store
|
||||||
self.auth = Auth(hs)
|
self.auth = Auth(hs)
|
||||||
|
|
||||||
|
@ -61,7 +63,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.insert_client_ip = simple_async_mock(None)
|
self.store.insert_client_ip = simple_async_mock(None)
|
||||||
self.store.is_support_user = simple_async_mock(False)
|
self.store.is_support_user = simple_async_mock(False)
|
||||||
|
|
||||||
def test_get_user_by_req_user_valid_token(self):
|
def test_get_user_by_req_user_valid_token(self) -> None:
|
||||||
user_info = TokenLookupResult(
|
user_info = TokenLookupResult(
|
||||||
user_id=self.test_user, token_id=5, device_id="device"
|
user_id=self.test_user, token_id=5, device_id="device"
|
||||||
)
|
)
|
||||||
|
@ -74,7 +76,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||||
self.assertEqual(requester.user.to_string(), self.test_user)
|
self.assertEqual(requester.user.to_string(), self.test_user)
|
||||||
|
|
||||||
def test_get_user_by_req_user_bad_token(self):
|
def test_get_user_by_req_user_bad_token(self) -> None:
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
|
@ -86,7 +88,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(f.code, 401)
|
self.assertEqual(f.code, 401)
|
||||||
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
||||||
|
|
||||||
def test_get_user_by_req_user_missing_token(self):
|
def test_get_user_by_req_user_missing_token(self) -> None:
|
||||||
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
|
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
|
||||||
self.store.get_user_by_access_token = simple_async_mock(user_info)
|
self.store.get_user_by_access_token = simple_async_mock(user_info)
|
||||||
|
|
||||||
|
@ -98,7 +100,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(f.code, 401)
|
self.assertEqual(f.code, 401)
|
||||||
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
|
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_valid_token(self):
|
def test_get_user_by_req_appservice_valid_token(self) -> None:
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
||||||
)
|
)
|
||||||
|
@ -112,7 +114,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||||
self.assertEqual(requester.user.to_string(), self.test_user)
|
self.assertEqual(requester.user.to_string(), self.test_user)
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_valid_token_good_ip(self):
|
def test_get_user_by_req_appservice_valid_token_good_ip(self) -> None:
|
||||||
from netaddr import IPSet
|
from netaddr import IPSet
|
||||||
|
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
|
@ -131,7 +133,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||||
self.assertEqual(requester.user.to_string(), self.test_user)
|
self.assertEqual(requester.user.to_string(), self.test_user)
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_valid_token_bad_ip(self):
|
def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None:
|
||||||
from netaddr import IPSet
|
from netaddr import IPSet
|
||||||
|
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
|
@ -153,7 +155,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(f.code, 401)
|
self.assertEqual(f.code, 401)
|
||||||
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_bad_token(self):
|
def test_get_user_by_req_appservice_bad_token(self) -> None:
|
||||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||||
|
|
||||||
|
@ -166,7 +168,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(f.code, 401)
|
self.assertEqual(f.code, 401)
|
||||||
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_missing_token(self):
|
def test_get_user_by_req_appservice_missing_token(self) -> None:
|
||||||
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
|
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||||
|
@ -179,7 +181,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(f.code, 401)
|
self.assertEqual(f.code, 401)
|
||||||
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
|
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
|
def test_get_user_by_req_appservice_valid_token_valid_user_id(self) -> None:
|
||||||
masquerading_user_id = b"@doppelganger:matrix.org"
|
masquerading_user_id = b"@doppelganger:matrix.org"
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
||||||
|
@ -200,7 +202,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
requester.user.to_string(), masquerading_user_id.decode("utf8")
|
requester.user.to_string(), masquerading_user_id.decode("utf8")
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
def test_get_user_by_req_appservice_valid_token_bad_user_id(self) -> None:
|
||||||
masquerading_user_id = b"@doppelganger:matrix.org"
|
masquerading_user_id = b"@doppelganger:matrix.org"
|
||||||
app_service = Mock(
|
app_service = Mock(
|
||||||
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
||||||
|
@ -217,7 +219,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.get_failure(self.auth.get_user_by_req(request), AuthError)
|
self.get_failure(self.auth.get_user_by_req(request), AuthError)
|
||||||
|
|
||||||
@override_config({"experimental_features": {"msc3202_device_masquerading": True}})
|
@override_config({"experimental_features": {"msc3202_device_masquerading": True}})
|
||||||
def test_get_user_by_req_appservice_valid_token_valid_device_id(self):
|
def test_get_user_by_req_appservice_valid_token_valid_device_id(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that when an application service passes the device_id URL parameter
|
Tests that when an application service passes the device_id URL parameter
|
||||||
with the ID of a valid device for the user in question,
|
with the ID of a valid device for the user in question,
|
||||||
|
@ -249,7 +251,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(requester.device_id, masquerading_device_id.decode("utf8"))
|
self.assertEqual(requester.device_id, masquerading_device_id.decode("utf8"))
|
||||||
|
|
||||||
@override_config({"experimental_features": {"msc3202_device_masquerading": True}})
|
@override_config({"experimental_features": {"msc3202_device_masquerading": True}})
|
||||||
def test_get_user_by_req_appservice_valid_token_invalid_device_id(self):
|
def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that when an application service passes the device_id URL parameter
|
Tests that when an application service passes the device_id URL parameter
|
||||||
with an ID that is not a valid device ID for the user in question,
|
with an ID that is not a valid device ID for the user in question,
|
||||||
|
@ -279,7 +281,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(failure.value.code, 400)
|
self.assertEqual(failure.value.code, 400)
|
||||||
self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE)
|
self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE)
|
||||||
|
|
||||||
def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self):
|
def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None:
|
||||||
self.store.get_user_by_access_token = simple_async_mock(
|
self.store.get_user_by_access_token = simple_async_mock(
|
||||||
TokenLookupResult(
|
TokenLookupResult(
|
||||||
user_id="@baldrick:matrix.org",
|
user_id="@baldrick:matrix.org",
|
||||||
|
@ -298,7 +300,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.get_success(self.auth.get_user_by_req(request))
|
self.get_success(self.auth.get_user_by_req(request))
|
||||||
self.store.insert_client_ip.assert_called_once()
|
self.store.insert_client_ip.assert_called_once()
|
||||||
|
|
||||||
def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self):
|
def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None:
|
||||||
self.auth._track_puppeted_user_ips = True
|
self.auth._track_puppeted_user_ips = True
|
||||||
self.store.get_user_by_access_token = simple_async_mock(
|
self.store.get_user_by_access_token = simple_async_mock(
|
||||||
TokenLookupResult(
|
TokenLookupResult(
|
||||||
|
@ -318,7 +320,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.get_success(self.auth.get_user_by_req(request))
|
self.get_success(self.auth.get_user_by_req(request))
|
||||||
self.assertEqual(self.store.insert_client_ip.call_count, 2)
|
self.assertEqual(self.store.insert_client_ip.call_count, 2)
|
||||||
|
|
||||||
def test_get_user_from_macaroon(self):
|
def test_get_user_from_macaroon(self) -> None:
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||||
|
|
||||||
user_id = "@baldrick:matrix.org"
|
user_id = "@baldrick:matrix.org"
|
||||||
|
@ -336,7 +338,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.auth.get_user_by_access_token(serialized), InvalidClientTokenError
|
self.auth.get_user_by_access_token(serialized), InvalidClientTokenError
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_guest_user_from_macaroon(self):
|
def test_get_guest_user_from_macaroon(self) -> None:
|
||||||
self.store.get_user_by_id = simple_async_mock({"is_guest": True})
|
self.store.get_user_by_id = simple_async_mock({"is_guest": True})
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||||
|
|
||||||
|
@ -357,7 +359,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertTrue(user_info.is_guest)
|
self.assertTrue(user_info.is_guest)
|
||||||
self.store.get_user_by_id.assert_called_with(user_id)
|
self.store.get_user_by_id.assert_called_with(user_id)
|
||||||
|
|
||||||
def test_blocking_mau(self):
|
def test_blocking_mau(self) -> None:
|
||||||
self.auth_blocking._limit_usage_by_mau = False
|
self.auth_blocking._limit_usage_by_mau = False
|
||||||
self.auth_blocking._max_mau_value = 50
|
self.auth_blocking._max_mau_value = 50
|
||||||
lots_of_users = 100
|
lots_of_users = 100
|
||||||
|
@ -381,7 +383,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
|
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
|
||||||
self.get_success(self.auth_blocking.check_auth_blocking())
|
self.get_success(self.auth_blocking.check_auth_blocking())
|
||||||
|
|
||||||
def test_blocking_mau__depending_on_user_type(self):
|
def test_blocking_mau__depending_on_user_type(self) -> None:
|
||||||
self.auth_blocking._max_mau_value = 50
|
self.auth_blocking._max_mau_value = 50
|
||||||
self.auth_blocking._limit_usage_by_mau = True
|
self.auth_blocking._limit_usage_by_mau = True
|
||||||
|
|
||||||
|
@ -400,7 +402,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
# Real users not allowed
|
# Real users not allowed
|
||||||
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
|
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
|
||||||
|
|
||||||
def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
|
def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
self.auth_blocking._max_mau_value = 50
|
self.auth_blocking._max_mau_value = 50
|
||||||
self.auth_blocking._limit_usage_by_mau = True
|
self.auth_blocking._limit_usage_by_mau = True
|
||||||
self.auth_blocking._track_appservice_user_ips = False
|
self.auth_blocking._track_appservice_user_ips = False
|
||||||
|
@ -418,7 +422,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
sender="@appservice:sender",
|
sender="@appservice:sender",
|
||||||
)
|
)
|
||||||
requester = Requester(
|
requester = Requester(
|
||||||
user="@appservice:server",
|
user=UserID.from_string("@appservice:server"),
|
||||||
access_token_id=None,
|
access_token_id=None,
|
||||||
device_id="FOOBAR",
|
device_id="FOOBAR",
|
||||||
is_guest=False,
|
is_guest=False,
|
||||||
|
@ -428,7 +432,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.get_success(self.auth_blocking.check_auth_blocking(requester=requester))
|
self.get_success(self.auth_blocking.check_auth_blocking(requester=requester))
|
||||||
|
|
||||||
def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
|
def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
self.auth_blocking._max_mau_value = 50
|
self.auth_blocking._max_mau_value = 50
|
||||||
self.auth_blocking._limit_usage_by_mau = True
|
self.auth_blocking._limit_usage_by_mau = True
|
||||||
self.auth_blocking._track_appservice_user_ips = True
|
self.auth_blocking._track_appservice_user_ips = True
|
||||||
|
@ -446,7 +452,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
sender="@appservice:sender",
|
sender="@appservice:sender",
|
||||||
)
|
)
|
||||||
requester = Requester(
|
requester = Requester(
|
||||||
user="@appservice:server",
|
user=UserID.from_string("@appservice:server"),
|
||||||
access_token_id=None,
|
access_token_id=None,
|
||||||
device_id="FOOBAR",
|
device_id="FOOBAR",
|
||||||
is_guest=False,
|
is_guest=False,
|
||||||
|
@ -459,7 +465,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
ResourceLimitError,
|
ResourceLimitError,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_reserved_threepid(self):
|
def test_reserved_threepid(self) -> None:
|
||||||
self.auth_blocking._limit_usage_by_mau = True
|
self.auth_blocking._limit_usage_by_mau = True
|
||||||
self.auth_blocking._max_mau_value = 1
|
self.auth_blocking._max_mau_value = 1
|
||||||
self.store.get_monthly_active_count = simple_async_mock(2)
|
self.store.get_monthly_active_count = simple_async_mock(2)
|
||||||
|
@ -476,7 +482,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid))
|
self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid))
|
||||||
|
|
||||||
def test_hs_disabled(self):
|
def test_hs_disabled(self) -> None:
|
||||||
self.auth_blocking._hs_disabled = True
|
self.auth_blocking._hs_disabled = True
|
||||||
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
|
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
|
||||||
e = self.get_failure(
|
e = self.get_failure(
|
||||||
|
@ -486,7 +492,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||||
self.assertEqual(e.value.code, 403)
|
self.assertEqual(e.value.code, 403)
|
||||||
|
|
||||||
def test_hs_disabled_no_server_notices_user(self):
|
def test_hs_disabled_no_server_notices_user(self) -> None:
|
||||||
"""Check that 'hs_disabled_message' works correctly when there is no
|
"""Check that 'hs_disabled_message' works correctly when there is no
|
||||||
server_notices user.
|
server_notices user.
|
||||||
"""
|
"""
|
||||||
|
@ -503,7 +509,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||||
self.assertEqual(e.value.code, 403)
|
self.assertEqual(e.value.code, 403)
|
||||||
|
|
||||||
def test_server_notices_mxid_special_cased(self):
|
def test_server_notices_mxid_special_cased(self) -> None:
|
||||||
self.auth_blocking._hs_disabled = True
|
self.auth_blocking._hs_disabled = True
|
||||||
user = "@user:server"
|
user = "@user:server"
|
||||||
self.auth_blocking._server_notices_mxid = user
|
self.auth_blocking._server_notices_mxid = user
|
||||||
|
|
|
@ -14,40 +14,36 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import List
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import jsonschema
|
import jsonschema
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EduTypes, EventContentFields
|
from synapse.api.constants import EduTypes, EventContentFields
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.api.filtering import Filter
|
from synapse.api.filtering import Filter
|
||||||
from synapse.events import make_event_from_dict
|
from synapse.api.presence import UserPresenceState
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.events.test_utils import MockEvent
|
||||||
|
|
||||||
user_localpart = "test_user"
|
user_localpart = "test_user"
|
||||||
|
|
||||||
|
|
||||||
def MockEvent(**kwargs):
|
|
||||||
if "event_id" not in kwargs:
|
|
||||||
kwargs["event_id"] = "fake_event_id"
|
|
||||||
if "type" not in kwargs:
|
|
||||||
kwargs["type"] = "fake_type"
|
|
||||||
if "content" not in kwargs:
|
|
||||||
kwargs["content"] = {}
|
|
||||||
return make_event_from_dict(kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class FilteringTestCase(unittest.HomeserverTestCase):
|
class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.filtering = hs.get_filtering()
|
self.filtering = hs.get_filtering()
|
||||||
self.datastore = hs.get_datastores().main
|
self.datastore = hs.get_datastores().main
|
||||||
|
|
||||||
def test_errors_on_invalid_filters(self):
|
def test_errors_on_invalid_filters(self) -> None:
|
||||||
# See USER_FILTER_SCHEMA for the filter schema.
|
# See USER_FILTER_SCHEMA for the filter schema.
|
||||||
invalid_filters = [
|
invalid_filters: List[JsonDict] = [
|
||||||
# `account_data` must be a dictionary
|
# `account_data` must be a dictionary
|
||||||
{"account_data": "Hello World"},
|
{"account_data": "Hello World"},
|
||||||
# `event_fields` entries must not contain backslashes
|
# `event_fields` entries must not contain backslashes
|
||||||
|
@ -63,10 +59,10 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
with self.assertRaises(SynapseError):
|
with self.assertRaises(SynapseError):
|
||||||
self.filtering.check_valid_filter(filter)
|
self.filtering.check_valid_filter(filter)
|
||||||
|
|
||||||
def test_ignores_unknown_filter_fields(self):
|
def test_ignores_unknown_filter_fields(self) -> None:
|
||||||
# For forward compatibility, we must ignore unknown filter fields.
|
# For forward compatibility, we must ignore unknown filter fields.
|
||||||
# See USER_FILTER_SCHEMA for the filter schema.
|
# See USER_FILTER_SCHEMA for the filter schema.
|
||||||
filters = [
|
filters: List[JsonDict] = [
|
||||||
{"org.matrix.msc9999.future_option": True},
|
{"org.matrix.msc9999.future_option": True},
|
||||||
{"presence": {"org.matrix.msc9999.future_option": True}},
|
{"presence": {"org.matrix.msc9999.future_option": True}},
|
||||||
{"room": {"org.matrix.msc9999.future_option": True}},
|
{"room": {"org.matrix.msc9999.future_option": True}},
|
||||||
|
@ -76,8 +72,8 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
self.filtering.check_valid_filter(filter)
|
self.filtering.check_valid_filter(filter)
|
||||||
# Must not raise.
|
# Must not raise.
|
||||||
|
|
||||||
def test_valid_filters(self):
|
def test_valid_filters(self) -> None:
|
||||||
valid_filters = [
|
valid_filters: List[JsonDict] = [
|
||||||
{
|
{
|
||||||
"room": {
|
"room": {
|
||||||
"timeline": {"limit": 20},
|
"timeline": {"limit": 20},
|
||||||
|
@ -132,22 +128,22 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
except jsonschema.ValidationError as e:
|
except jsonschema.ValidationError as e:
|
||||||
self.fail(e)
|
self.fail(e)
|
||||||
|
|
||||||
def test_limits_are_applied(self):
|
def test_limits_are_applied(self) -> None:
|
||||||
# TODO
|
# TODO
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_definition_types_works_with_literals(self):
|
def test_definition_types_works_with_literals(self) -> None:
|
||||||
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
|
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
|
||||||
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
||||||
|
|
||||||
self.assertTrue(Filter(self.hs, definition)._check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_types_works_with_wildcards(self):
|
def test_definition_types_works_with_wildcards(self) -> None:
|
||||||
definition = {"types": ["m.*", "org.matrix.foo.bar"]}
|
definition = {"types": ["m.*", "org.matrix.foo.bar"]}
|
||||||
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
||||||
self.assertTrue(Filter(self.hs, definition)._check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_types_works_with_unknowns(self):
|
def test_definition_types_works_with_unknowns(self) -> None:
|
||||||
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
|
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar",
|
||||||
|
@ -156,24 +152,24 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(self.hs, definition)._check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_types_works_with_literals(self):
|
def test_definition_not_types_works_with_literals(self) -> None:
|
||||||
definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]}
|
definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]}
|
||||||
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
|
||||||
self.assertFalse(Filter(self.hs, definition)._check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_types_works_with_wildcards(self):
|
def test_definition_not_types_works_with_wildcards(self) -> None:
|
||||||
definition = {"not_types": ["m.room.message", "org.matrix.*"]}
|
definition = {"not_types": ["m.room.message", "org.matrix.*"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
|
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(self.hs, definition)._check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_types_works_with_unknowns(self):
|
def test_definition_not_types_works_with_unknowns(self) -> None:
|
||||||
definition = {"not_types": ["m.*", "org.*"]}
|
definition = {"not_types": ["m.*", "org.*"]}
|
||||||
event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar")
|
event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar")
|
||||||
self.assertTrue(Filter(self.hs, definition)._check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_types_takes_priority_over_types(self):
|
def test_definition_not_types_takes_priority_over_types(self) -> None:
|
||||||
definition = {
|
definition = {
|
||||||
"not_types": ["m.*", "org.*"],
|
"not_types": ["m.*", "org.*"],
|
||||||
"types": ["m.room.message", "m.room.topic"],
|
"types": ["m.room.message", "m.room.topic"],
|
||||||
|
@ -181,35 +177,35 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
|
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
|
||||||
self.assertFalse(Filter(self.hs, definition)._check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_senders_works_with_literals(self):
|
def test_definition_senders_works_with_literals(self) -> None:
|
||||||
definition = {"senders": ["@flibble:wibble"]}
|
definition = {"senders": ["@flibble:wibble"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
|
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||||
)
|
)
|
||||||
self.assertTrue(Filter(self.hs, definition)._check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_senders_works_with_unknowns(self):
|
def test_definition_senders_works_with_unknowns(self) -> None:
|
||||||
definition = {"senders": ["@flibble:wibble"]}
|
definition = {"senders": ["@flibble:wibble"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
|
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(self.hs, definition)._check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_senders_works_with_literals(self):
|
def test_definition_not_senders_works_with_literals(self) -> None:
|
||||||
definition = {"not_senders": ["@flibble:wibble"]}
|
definition = {"not_senders": ["@flibble:wibble"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
|
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(self.hs, definition)._check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_senders_works_with_unknowns(self):
|
def test_definition_not_senders_works_with_unknowns(self) -> None:
|
||||||
definition = {"not_senders": ["@flibble:wibble"]}
|
definition = {"not_senders": ["@flibble:wibble"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
|
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
|
||||||
)
|
)
|
||||||
self.assertTrue(Filter(self.hs, definition)._check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_senders_takes_priority_over_senders(self):
|
def test_definition_not_senders_takes_priority_over_senders(self) -> None:
|
||||||
definition = {
|
definition = {
|
||||||
"not_senders": ["@misspiggy:muppets"],
|
"not_senders": ["@misspiggy:muppets"],
|
||||||
"senders": ["@kermit:muppets", "@misspiggy:muppets"],
|
"senders": ["@kermit:muppets", "@misspiggy:muppets"],
|
||||||
|
@ -219,14 +215,14 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(self.hs, definition)._check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_rooms_works_with_literals(self):
|
def test_definition_rooms_works_with_literals(self) -> None:
|
||||||
definition = {"rooms": ["!secretbase:unknown"]}
|
definition = {"rooms": ["!secretbase:unknown"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
|
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
|
||||||
)
|
)
|
||||||
self.assertTrue(Filter(self.hs, definition)._check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_rooms_works_with_unknowns(self):
|
def test_definition_rooms_works_with_unknowns(self) -> None:
|
||||||
definition = {"rooms": ["!secretbase:unknown"]}
|
definition = {"rooms": ["!secretbase:unknown"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar",
|
||||||
|
@ -235,7 +231,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(self.hs, definition)._check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_rooms_works_with_literals(self):
|
def test_definition_not_rooms_works_with_literals(self) -> None:
|
||||||
definition = {"not_rooms": ["!anothersecretbase:unknown"]}
|
definition = {"not_rooms": ["!anothersecretbase:unknown"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar",
|
||||||
|
@ -244,7 +240,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(self.hs, definition)._check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_rooms_works_with_unknowns(self):
|
def test_definition_not_rooms_works_with_unknowns(self) -> None:
|
||||||
definition = {"not_rooms": ["!secretbase:unknown"]}
|
definition = {"not_rooms": ["!secretbase:unknown"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar",
|
||||||
|
@ -253,7 +249,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertTrue(Filter(self.hs, definition)._check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_not_rooms_takes_priority_over_rooms(self):
|
def test_definition_not_rooms_takes_priority_over_rooms(self) -> None:
|
||||||
definition = {
|
definition = {
|
||||||
"not_rooms": ["!secretbase:unknown"],
|
"not_rooms": ["!secretbase:unknown"],
|
||||||
"rooms": ["!secretbase:unknown"],
|
"rooms": ["!secretbase:unknown"],
|
||||||
|
@ -263,7 +259,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(self.hs, definition)._check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_combined_event(self):
|
def test_definition_combined_event(self) -> None:
|
||||||
definition = {
|
definition = {
|
||||||
"not_senders": ["@misspiggy:muppets"],
|
"not_senders": ["@misspiggy:muppets"],
|
||||||
"senders": ["@kermit:muppets"],
|
"senders": ["@kermit:muppets"],
|
||||||
|
@ -279,7 +275,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertTrue(Filter(self.hs, definition)._check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_combined_event_bad_sender(self):
|
def test_definition_combined_event_bad_sender(self) -> None:
|
||||||
definition = {
|
definition = {
|
||||||
"not_senders": ["@misspiggy:muppets"],
|
"not_senders": ["@misspiggy:muppets"],
|
||||||
"senders": ["@kermit:muppets"],
|
"senders": ["@kermit:muppets"],
|
||||||
|
@ -295,7 +291,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(self.hs, definition)._check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_combined_event_bad_room(self):
|
def test_definition_combined_event_bad_room(self) -> None:
|
||||||
definition = {
|
definition = {
|
||||||
"not_senders": ["@misspiggy:muppets"],
|
"not_senders": ["@misspiggy:muppets"],
|
||||||
"senders": ["@kermit:muppets"],
|
"senders": ["@kermit:muppets"],
|
||||||
|
@ -311,7 +307,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(self.hs, definition)._check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_definition_combined_event_bad_type(self):
|
def test_definition_combined_event_bad_type(self) -> None:
|
||||||
definition = {
|
definition = {
|
||||||
"not_senders": ["@misspiggy:muppets"],
|
"not_senders": ["@misspiggy:muppets"],
|
||||||
"senders": ["@kermit:muppets"],
|
"senders": ["@kermit:muppets"],
|
||||||
|
@ -327,7 +323,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertFalse(Filter(self.hs, definition)._check(event))
|
self.assertFalse(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_filter_labels(self):
|
def test_filter_labels(self) -> None:
|
||||||
definition = {"org.matrix.labels": ["#fun"]}
|
definition = {"org.matrix.labels": ["#fun"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar",
|
||||||
|
@ -356,7 +352,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertTrue(Filter(self.hs, definition)._check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_filter_not_labels(self):
|
def test_filter_not_labels(self) -> None:
|
||||||
definition = {"org.matrix.not_labels": ["#fun"]}
|
definition = {"org.matrix.not_labels": ["#fun"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar",
|
||||||
|
@ -377,7 +373,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertTrue(Filter(self.hs, definition)._check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
@unittest.override_config({"experimental_features": {"msc3874_enabled": True}})
|
@unittest.override_config({"experimental_features": {"msc3874_enabled": True}})
|
||||||
def test_filter_rel_type(self):
|
def test_filter_rel_type(self) -> None:
|
||||||
definition = {"org.matrix.msc3874.rel_types": ["m.thread"]}
|
definition = {"org.matrix.msc3874.rel_types": ["m.thread"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar",
|
||||||
|
@ -407,7 +403,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertTrue(Filter(self.hs, definition)._check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
@unittest.override_config({"experimental_features": {"msc3874_enabled": True}})
|
@unittest.override_config({"experimental_features": {"msc3874_enabled": True}})
|
||||||
def test_filter_not_rel_type(self):
|
def test_filter_not_rel_type(self) -> None:
|
||||||
definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]}
|
definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]}
|
||||||
event = MockEvent(
|
event = MockEvent(
|
||||||
sender="@foo:bar",
|
sender="@foo:bar",
|
||||||
|
@ -436,15 +432,25 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertTrue(Filter(self.hs, definition)._check(event))
|
self.assertTrue(Filter(self.hs, definition)._check(event))
|
||||||
|
|
||||||
def test_filter_presence_match(self):
|
def test_filter_presence_match(self) -> None:
|
||||||
user_filter_json = {"presence": {"types": ["m.*"]}}
|
"""Check that filter_presence return events which matches the filter."""
|
||||||
|
user_filter_json = {"presence": {"senders": ["@foo:bar"]}}
|
||||||
filter_id = self.get_success(
|
filter_id = self.get_success(
|
||||||
self.datastore.add_user_filter(
|
self.datastore.add_user_filter(
|
||||||
user_localpart=user_localpart, user_filter=user_filter_json
|
user_localpart=user_localpart, user_filter=user_filter_json
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
event = MockEvent(sender="@foo:bar", type="m.profile")
|
presence_states = [
|
||||||
events = [event]
|
UserPresenceState(
|
||||||
|
user_id="@foo:bar",
|
||||||
|
state="unavailable",
|
||||||
|
last_active_ts=0,
|
||||||
|
last_federation_update_ts=0,
|
||||||
|
last_user_sync_ts=0,
|
||||||
|
status_msg=None,
|
||||||
|
currently_active=False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
user_filter = self.get_success(
|
user_filter = self.get_success(
|
||||||
self.filtering.get_user_filter(
|
self.filtering.get_user_filter(
|
||||||
|
@ -452,23 +458,29 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
results = self.get_success(user_filter.filter_presence(events=events))
|
results = self.get_success(user_filter.filter_presence(presence_states))
|
||||||
self.assertEqual(events, results)
|
self.assertEqual(presence_states, results)
|
||||||
|
|
||||||
def test_filter_presence_no_match(self):
|
def test_filter_presence_no_match(self) -> None:
|
||||||
user_filter_json = {"presence": {"types": ["m.*"]}}
|
"""Check that filter_presence does not return events rejected by the filter."""
|
||||||
|
user_filter_json = {"presence": {"not_senders": ["@foo:bar"]}}
|
||||||
|
|
||||||
filter_id = self.get_success(
|
filter_id = self.get_success(
|
||||||
self.datastore.add_user_filter(
|
self.datastore.add_user_filter(
|
||||||
user_localpart=user_localpart + "2", user_filter=user_filter_json
|
user_localpart=user_localpart + "2", user_filter=user_filter_json
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
event = MockEvent(
|
presence_states = [
|
||||||
event_id="$asdasd:localhost",
|
UserPresenceState(
|
||||||
sender="@foo:bar",
|
user_id="@foo:bar",
|
||||||
type="custom.avatar.3d.crazy",
|
state="unavailable",
|
||||||
)
|
last_active_ts=0,
|
||||||
events = [event]
|
last_federation_update_ts=0,
|
||||||
|
last_user_sync_ts=0,
|
||||||
|
status_msg=None,
|
||||||
|
currently_active=False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
user_filter = self.get_success(
|
user_filter = self.get_success(
|
||||||
self.filtering.get_user_filter(
|
self.filtering.get_user_filter(
|
||||||
|
@ -476,10 +488,10 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
results = self.get_success(user_filter.filter_presence(events=events))
|
results = self.get_success(user_filter.filter_presence(presence_states))
|
||||||
self.assertEqual([], results)
|
self.assertEqual([], results)
|
||||||
|
|
||||||
def test_filter_room_state_match(self):
|
def test_filter_room_state_match(self) -> None:
|
||||||
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||||
filter_id = self.get_success(
|
filter_id = self.get_success(
|
||||||
self.datastore.add_user_filter(
|
self.datastore.add_user_filter(
|
||||||
|
@ -498,7 +510,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
results = self.get_success(user_filter.filter_room_state(events=events))
|
results = self.get_success(user_filter.filter_room_state(events=events))
|
||||||
self.assertEqual(events, results)
|
self.assertEqual(events, results)
|
||||||
|
|
||||||
def test_filter_room_state_no_match(self):
|
def test_filter_room_state_no_match(self) -> None:
|
||||||
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||||
filter_id = self.get_success(
|
filter_id = self.get_success(
|
||||||
self.datastore.add_user_filter(
|
self.datastore.add_user_filter(
|
||||||
|
@ -519,7 +531,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
results = self.get_success(user_filter.filter_room_state(events))
|
results = self.get_success(user_filter.filter_room_state(events))
|
||||||
self.assertEqual([], results)
|
self.assertEqual([], results)
|
||||||
|
|
||||||
def test_filter_rooms(self):
|
def test_filter_rooms(self) -> None:
|
||||||
definition = {
|
definition = {
|
||||||
"rooms": ["!allowed:example.com", "!excluded:example.com"],
|
"rooms": ["!allowed:example.com", "!excluded:example.com"],
|
||||||
"not_rooms": ["!excluded:example.com"],
|
"not_rooms": ["!excluded:example.com"],
|
||||||
|
@ -535,7 +547,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(filtered_room_ids, ["!allowed:example.com"])
|
self.assertEqual(filtered_room_ids, ["!allowed:example.com"])
|
||||||
|
|
||||||
def test_filter_relations(self):
|
def test_filter_relations(self) -> None:
|
||||||
events = [
|
events = [
|
||||||
# An event without a relation.
|
# An event without a relation.
|
||||||
MockEvent(
|
MockEvent(
|
||||||
|
@ -551,9 +563,8 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
type="org.matrix.custom.event",
|
type="org.matrix.custom.event",
|
||||||
room_id="!foo:bar",
|
room_id="!foo:bar",
|
||||||
),
|
),
|
||||||
# Non-EventBase objects get passed through.
|
|
||||||
{},
|
|
||||||
]
|
]
|
||||||
|
jsondicts: List[JsonDict] = [{}]
|
||||||
|
|
||||||
# For the following tests we patch the datastore method (intead of injecting
|
# For the following tests we patch the datastore method (intead of injecting
|
||||||
# events). This is a bit cheeky, but tests the logic of _check_event_relations.
|
# events). This is a bit cheeky, but tests the logic of _check_event_relations.
|
||||||
|
@ -561,7 +572,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
# Filter for a particular sender.
|
# Filter for a particular sender.
|
||||||
definition = {"related_by_senders": ["@foo:bar"]}
|
definition = {"related_by_senders": ["@foo:bar"]}
|
||||||
|
|
||||||
async def events_have_relations(*args, **kwargs):
|
async def events_have_relations(*args: object, **kwargs: object) -> List[str]:
|
||||||
return ["$with_relation"]
|
return ["$with_relation"]
|
||||||
|
|
||||||
with patch.object(
|
with patch.object(
|
||||||
|
@ -572,9 +583,17 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
Filter(self.hs, definition)._check_event_relations(events)
|
Filter(self.hs, definition)._check_event_relations(events)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(filtered_events, events[1:])
|
# Non-EventBase objects get passed through.
|
||||||
|
filtered_jsondicts = list(
|
||||||
|
self.get_success(
|
||||||
|
Filter(self.hs, definition)._check_event_relations(jsondicts)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def test_add_filter(self):
|
self.assertEqual(filtered_events, events[1:])
|
||||||
|
self.assertEqual(filtered_jsondicts, [{}])
|
||||||
|
|
||||||
|
def test_add_filter(self) -> None:
|
||||||
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||||
|
|
||||||
filter_id = self.get_success(
|
filter_id = self.get_success(
|
||||||
|
@ -595,7 +614,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_filter(self):
|
def test_get_filter(self) -> None:
|
||||||
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
|
||||||
|
|
||||||
filter_id = self.get_success(
|
filter_id = self.get_success(
|
||||||
|
|
|
@ -6,7 +6,7 @@ from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class TestRatelimiter(unittest.HomeserverTestCase):
|
class TestRatelimiter(unittest.HomeserverTestCase):
|
||||||
def test_allowed_via_can_do_action(self):
|
def test_allowed_via_can_do_action(self) -> None:
|
||||||
limiter = Ratelimiter(
|
limiter = Ratelimiter(
|
||||||
store=self.hs.get_datastores().main,
|
store=self.hs.get_datastores().main,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
|
@ -31,7 +31,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEqual(20.0, time_allowed)
|
self.assertEqual(20.0, time_allowed)
|
||||||
|
|
||||||
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
|
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self) -> None:
|
||||||
appservice = ApplicationService(
|
appservice = ApplicationService(
|
||||||
token="fake_token",
|
token="fake_token",
|
||||||
id="foo",
|
id="foo",
|
||||||
|
@ -64,7 +64,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEqual(20.0, time_allowed)
|
self.assertEqual(20.0, time_allowed)
|
||||||
|
|
||||||
def test_allowed_appservice_via_can_requester_do_action(self):
|
def test_allowed_appservice_via_can_requester_do_action(self) -> None:
|
||||||
appservice = ApplicationService(
|
appservice = ApplicationService(
|
||||||
token="fake_token",
|
token="fake_token",
|
||||||
id="foo",
|
id="foo",
|
||||||
|
@ -97,7 +97,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEqual(-1, time_allowed)
|
self.assertEqual(-1, time_allowed)
|
||||||
|
|
||||||
def test_allowed_via_ratelimit(self):
|
def test_allowed_via_ratelimit(self) -> None:
|
||||||
limiter = Ratelimiter(
|
limiter = Ratelimiter(
|
||||||
store=self.hs.get_datastores().main,
|
store=self.hs.get_datastores().main,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
|
@ -120,7 +120,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
|
||||||
limiter.ratelimit(None, key="test_id", _time_now_s=10)
|
limiter.ratelimit(None, key="test_id", _time_now_s=10)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_allowed_via_can_do_action_and_overriding_parameters(self):
|
def test_allowed_via_can_do_action_and_overriding_parameters(self) -> None:
|
||||||
"""Test that we can override options of can_do_action that would otherwise fail
|
"""Test that we can override options of can_do_action that would otherwise fail
|
||||||
an action
|
an action
|
||||||
"""
|
"""
|
||||||
|
@ -169,7 +169,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
|
||||||
self.assertTrue(allowed)
|
self.assertTrue(allowed)
|
||||||
self.assertEqual(1.0, time_allowed)
|
self.assertEqual(1.0, time_allowed)
|
||||||
|
|
||||||
def test_allowed_via_ratelimit_and_overriding_parameters(self):
|
def test_allowed_via_ratelimit_and_overriding_parameters(self) -> None:
|
||||||
"""Test that we can override options of the ratelimit method that would otherwise
|
"""Test that we can override options of the ratelimit method that would otherwise
|
||||||
fail an action
|
fail an action
|
||||||
"""
|
"""
|
||||||
|
@ -204,7 +204,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
|
||||||
limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
|
limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_pruning(self):
|
def test_pruning(self) -> None:
|
||||||
limiter = Ratelimiter(
|
limiter = Ratelimiter(
|
||||||
store=self.hs.get_datastores().main,
|
store=self.hs.get_datastores().main,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
|
@ -223,7 +223,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertNotIn("test_id_1", limiter.actions)
|
self.assertNotIn("test_id_1", limiter.actions)
|
||||||
|
|
||||||
def test_db_user_override(self):
|
def test_db_user_override(self) -> None:
|
||||||
"""Test that users that have ratelimiting disabled in the DB aren't
|
"""Test that users that have ratelimiting disabled in the DB aren't
|
||||||
ratelimited.
|
ratelimited.
|
||||||
"""
|
"""
|
||||||
|
@ -250,7 +250,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
|
||||||
for _ in range(20):
|
for _ in range(20):
|
||||||
self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
|
self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
|
||||||
|
|
||||||
def test_multiple_actions(self):
|
def test_multiple_actions(self) -> None:
|
||||||
limiter = Ratelimiter(
|
limiter = Ratelimiter(
|
||||||
store=self.hs.get_datastores().main,
|
store=self.hs.get_datastores().main,
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
|
|
|
@ -35,6 +35,8 @@ def MockEvent(**kwargs: Any) -> EventBase:
|
||||||
kwargs["event_id"] = "fake_event_id"
|
kwargs["event_id"] = "fake_event_id"
|
||||||
if "type" not in kwargs:
|
if "type" not in kwargs:
|
||||||
kwargs["type"] = "fake_type"
|
kwargs["type"] = "fake_type"
|
||||||
|
if "content" not in kwargs:
|
||||||
|
kwargs["content"] = {}
|
||||||
return make_event_from_dict(kwargs)
|
return make_event_from_dict(kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue