0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-07 07:08:57 +01:00

Merge branch 'client_v2_filter' into client_v2_sync

Conflicts:
	synapse/rest/client/v2_alpha/__init__.py
This commit is contained in:
Mark Haines 2015-01-29 14:58:00 +00:00
commit 396a67a09a
15 changed files with 1100 additions and 20 deletions

246
synapse/api/filtering.py Normal file
View file

@ -0,0 +1,246 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID
class Filtering(object):
def __init__(self, hs):
super(Filtering, self).__init__()
self.store = hs.get_datastore()
def get_user_filter(self, user_localpart, filter_id):
return self.store.get_user_filter(user_localpart, filter_id)
def add_user_filter(self, user_localpart, user_filter):
self._check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, user_filter)
def filter_public_user_data(self, events, user, filter_id):
return self._filter_on_key(
events, user, filter_id, ["public_user_data"]
)
def filter_private_user_data(self, events, user, filter_id):
return self._filter_on_key(
events, user, filter_id, ["private_user_data"]
)
def filter_room_state(self, events, user, filter_id):
return self._filter_on_key(
events, user, filter_id, ["room", "state"]
)
def filter_room_events(self, events, user, filter_id):
return self._filter_on_key(
events, user, filter_id, ["room", "events"]
)
def filter_room_ephemeral(self, events, user, filter_id):
return self._filter_on_key(
events, user, filter_id, ["room", "ephemeral"]
)
# TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for
# them however
@defer.inlineCallbacks
def _filter_on_key(self, events, user, filter_id, keys):
filter_json = yield self.get_user_filter(user.localpart, filter_id)
if not filter_json:
defer.returnValue(events)
try:
# extract the right definition from the filter
definition = filter_json
for key in keys:
definition = definition[key]
defer.returnValue(self._filter_with_definition(events, definition))
except KeyError:
# return all events if definition isn't specified.
defer.returnValue(events)
def _filter_with_definition(self, events, definition):
return [e for e in events if self._passes_definition(definition, e)]
def _passes_definition(self, definition, event):
"""Check if the event passes through the given definition.
Args:
definition(dict): The definition to check against.
event(Event): The event to check.
Returns:
True if the event passes through the filter.
"""
# Algorithm notes:
# For each key in the definition, check the event meets the criteria:
# * For types: Literal match or prefix match (if ends with wildcard)
# * For senders/rooms: Literal match only
# * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
# and 'not_types' then it is treated as only being in 'not_types')
# room checks
if hasattr(event, "room_id"):
room_id = event.room_id
allow_rooms = definition["rooms"] if "rooms" in definition else None
reject_rooms = (
definition["not_rooms"] if "not_rooms" in definition else None
)
if reject_rooms and room_id in reject_rooms:
return False
if allow_rooms and room_id not in allow_rooms:
return False
# sender checks
if hasattr(event, "sender"):
# Should we be including event.state_key for some event types?
sender = event.sender
allow_senders = (
definition["senders"] if "senders" in definition else None
)
reject_senders = (
definition["not_senders"] if "not_senders" in definition else None
)
if reject_senders and sender in reject_senders:
return False
if allow_senders and sender not in allow_senders:
return False
# type checks
if "not_types" in definition:
for def_type in definition["not_types"]:
if self._event_matches_type(event, def_type):
return False
if "types" in definition:
included = False
for def_type in definition["types"]:
if self._event_matches_type(event, def_type):
included = True
break
if not included:
return False
return True
def _event_matches_type(self, event, def_type):
if def_type.endswith("*"):
type_prefix = def_type[:-1]
return event.type.startswith(type_prefix)
else:
return event.type == def_type
def _check_valid_filter(self, user_filter):
"""Check if the provided filter is valid.
This inspects all definitions contained within the filter.
Args:
user_filter(dict): The filter
Raises:
SynapseError: If the filter is not valid.
"""
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
top_level_definitions = [
"public_user_data", "private_user_data", "server_data"
]
room_level_definitions = [
"state", "events", "ephemeral"
]
for key in top_level_definitions:
if key in user_filter:
self._check_definition(user_filter[key])
if "room" in user_filter:
for key in room_level_definitions:
if key in user_filter["room"]:
self._check_definition(user_filter["room"][key])
def _check_definition(self, definition):
"""Check if the provided definition is valid.
This inspects not only the types but also the values to make sure they
make sense.
Args:
definition(dict): The filter definition
Raises:
SynapseError: If there was a problem with this definition.
"""
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
if type(definition) != dict:
raise SynapseError(
400, "Expected JSON object, not %s" % (definition,)
)
# check rooms are valid room IDs
room_id_keys = ["rooms", "not_rooms"]
for key in room_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for room_id in definition[key]:
RoomID.from_string(room_id)
# check senders are valid user IDs
user_id_keys = ["senders", "not_senders"]
for key in user_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for user_id in definition[key]:
UserID.from_string(user_id)
# TODO: We don't limit event type values but we probably should...
# check types are valid event types
event_keys = ["types", "not_types"]
for key in event_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for event_type in definition[key]:
if not isinstance(event_type, basestring):
raise SynapseError(400, "Event type should be a string")
try:
event_format = definition["format"]
if event_format not in ["federation", "events"]:
raise SynapseError(400, "Invalid format: %s" % (event_format,))
except KeyError:
pass # format is optional
try:
event_select_list = definition["select"]
for select_key in event_select_list:
if select_key not in ["event_id", "origin_server_ts",
"thread_id", "content", "content.body"]:
raise SynapseError(400, "Bad select: %s" % (select_key,))
except KeyError:
pass # select is optional
if ("bundle_updates" in definition and
type(definition["bundle_updates"]) != bool):
raise SynapseError(400, "Bad bundle_updates: expected bool.")

View file

@ -138,7 +138,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
user = yield self.auth.get_user_by_req(request) user, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
@ -184,7 +184,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
user = yield self.auth.get_user_by_req(request) user, _ = yield self.auth.get_user_by_req(request)
if 'device' in spec: if 'device' in spec:
rules = yield self.hs.get_datastore().get_push_rules_for_user_name( rules = yield self.hs.get_datastore().get_push_rules_for_user_name(
@ -215,7 +215,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user = yield self.auth.get_user_by_req(request) user, _ = yield self.auth.get_user_by_req(request)
# we build up the full structure and then decide which bits of it # we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is # to send which means doing unnecessary work sometimes but is

View file

@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
user = yield self.auth.get_user_by_req(request) user, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)

View file

@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import ( from . import (
sync, sync,
filter
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -31,3 +31,4 @@ class ClientV2AlphaRestResource(JsonResource):
@staticmethod @staticmethod
def register_servlets(client_resource, hs): def register_servlets(client_resource, hs):
sync.register_servlets(hs, client_resource) sync.register_servlets(hs, client_resource)
filter.register_servlets(hs, client_resource)

View file

@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
from synapse.http.servlet import RestServlet
from synapse.types import UserID
from ._base import client_v2_pattern
import json
import logging
logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs):
super(GetFilterRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
@defer.inlineCallbacks
def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id)
auth_user = yield self.auth.get_user_by_req(request)
if target_user != auth_user:
raise AuthError(403, "Cannot get filters for other users")
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only get filters for local users")
try:
filter_id = int(filter_id)
except:
raise SynapseError(400, "Invalid filter_id")
try:
filter = yield self.filtering.get_user_filter(
user_localpart=target_user.localpart,
filter_id=filter_id,
)
defer.returnValue((200, filter))
except KeyError:
raise SynapseError(400, "No such filter")
class CreateFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs):
super(CreateFilterRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
@defer.inlineCallbacks
def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id)
auth_user = yield self.auth.get_user_by_req(request)
if target_user != auth_user:
raise AuthError(403, "Cannot create filters for other users")
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only create filters for local users")
try:
content = json.loads(request.content.read())
# TODO(paul): check for required keys and invalid keys
except:
raise SynapseError(400, "Invalid filter definition")
filter_id = yield self.filtering.add_user_filter(
user_localpart=target_user.localpart,
user_filter=content,
)
defer.returnValue((200, {"filter_id": str(filter_id)}))
def register_servlets(hs, http_server):
GetFilterRestServlet(hs).register(http_server)
CreateFilterRestServlet(hs).register(http_server)

View file

@ -33,6 +33,7 @@ from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring from synapse.crypto.keyring import Keyring
from synapse.push.pusherpool import PusherPool from synapse.push.pusherpool import PusherPool
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.api.filtering import Filtering
class BaseHomeServer(object): class BaseHomeServer(object):
@ -81,6 +82,7 @@ class BaseHomeServer(object):
'keyring', 'keyring',
'pusherpool', 'pusherpool',
'event_builder_factory', 'event_builder_factory',
'filtering',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -200,5 +202,8 @@ class HomeServer(BaseHomeServer):
hostname=self.hostname, hostname=self.hostname,
) )
def build_filtering(self):
return Filtering(self)
def build_pusherpool(self): def build_pusherpool(self):
return PusherPool(self) return PusherPool(self)

View file

@ -32,9 +32,9 @@ from .event_federation import EventFederationStore
from .pusher import PusherStore from .pusher import PusherStore
from .push_rule import PushRuleStore from .push_rule import PushRuleStore
from .media_repository import MediaRepositoryStore from .media_repository import MediaRepositoryStore
from .state import StateStore from .state import StateStore
from .signatures import SignatureStore from .signatures import SignatureStore
from .filtering import FilteringStore
from syutil.base64util import decode_base64 from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
@ -64,6 +64,7 @@ SCHEMAS = [
"event_signatures", "event_signatures",
"pusher", "pusher",
"media_repository", "media_repository",
"filtering",
] ]
@ -85,6 +86,7 @@ class DataStore(RoomMemberStore, RoomStore,
DirectoryStore, KeyStore, StateStore, SignatureStore, DirectoryStore, KeyStore, StateStore, SignatureStore,
EventFederationStore, EventFederationStore,
MediaRepositoryStore, MediaRepositoryStore,
FilteringStore,
PusherStore, PusherStore,
PushRuleStore PushRuleStore
): ):

View file

@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore
import json
# TODO(paul)
_filters_for_user = {}
class FilteringStore(SQLBaseStore):
@defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id):
def_json = yield self._simple_select_one_onecol(
table="user_filters",
keyvalues={
"user_id": user_localpart,
"filter_id": filter_id,
},
retcol="filter_json",
allow_none=False,
)
defer.returnValue(json.loads(def_json))
def add_user_filter(self, user_localpart, user_filter):
def_json = json.dumps(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
# INSERT a new one
def _do_txn(txn):
sql = (
"SELECT MAX(filter_id) FROM user_filters "
"WHERE user_id = ?"
)
txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0]
if max_id is None:
filter_id = 0
else:
filter_id = max_id + 1
sql = (
"INSERT INTO user_filters (user_id, filter_id, filter_json)"
"VALUES(?, ?, ?)"
)
txn.execute(sql, (user_localpart, filter_id, def_json))
return filter_id
return self.runInteraction("add_user_filter", _do_txn)

View file

@ -175,14 +175,17 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, new_rule.values()) txn.execute(sql, new_rule.values())
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_push_rule(self, user_name, rule_id): def delete_push_rule(self, user_name, rule_id, **kwargs):
yield self._simple_delete_one( """
PushRuleTable.table_name, Delete a push rule. Args specify the row to be deleted and can be
{ any of the columns in the push_rule table, but below are the
'user_name': user_name, standard ones
'rule_id': rule_id
} Args:
) user_name (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted
"""
yield self._simple_delete_one(PushRuleTable.table_name, kwargs)
class RuleNotFoundException(Exception): class RuleNotFoundException(Exception):

View file

@ -0,0 +1,24 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS user_filters(
user_id TEXT,
filter_id INTEGER,
filter_json TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
);
CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
user_id, filter_id
);

View file

@ -0,0 +1,24 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS user_filters(
user_id TEXT,
filter_id INTEGER,
filter_json TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
);
CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
user_id, filter_id
);

View file

@ -82,10 +82,10 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
def parse(cls, string): def parse(cls, string):
try: try:
if string[0] == 's': if string[0] == 's':
return cls(None, int(string[1:])) return cls(topological=None, stream=int(string[1:]))
if string[0] == 't': if string[0] == 't':
parts = string[1:].split('-', 1) parts = string[1:].split('-', 1)
return cls(int(parts[1]), int(parts[0])) return cls(topological=int(parts[0]), stream=int(parts[1]))
except: except:
pass pass
raise SynapseError(400, "Invalid token %r" % (string,)) raise SynapseError(400, "Invalid token %r" % (string,))
@ -94,7 +94,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
def parse_stream_token(cls, string): def parse_stream_token(cls, string):
try: try:
if string[0] == 's': if string[0] == 's':
return cls(None, int(string[1:])) return cls(topological=None, stream=int(string[1:]))
except: except:
pass pass
raise SynapseError(400, "Invalid token %r" % (string,)) raise SynapseError(400, "Invalid token %r" % (string,))

506
tests/api/test_filtering.py Normal file
View file

@ -0,0 +1,506 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from tests import unittest
from twisted.internet import defer
from mock import Mock, NonCallableMock
from tests.utils import (
MockHttpResource, MockClock, DeferredMockCallable, SQLiteMemoryDbPool,
MockKey
)
from synapse.server import HomeServer
from synapse.types import UserID
user_localpart = "test_user"
MockEvent = namedtuple("MockEvent", "sender type room_id")
class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
self.mock_federation_resource = MockHttpResource()
self.mock_http_client = Mock(spec=[])
self.mock_http_client.put_json = DeferredMockCallable()
hs = HomeServer("test",
db_pool=db_pool,
handlers=None,
http_client=self.mock_http_client,
config=self.mock_config,
keyring=Mock(),
)
self.filtering = hs.get_filtering()
self.datastore = hs.get_datastore()
def test_definition_types_works_with_literals(self):
definition = {
"types": ["m.room.message", "org.matrix.foo.bar"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!foo:bar"
)
self.assertTrue(
self.filtering._passes_definition(definition, event)
)
def test_definition_types_works_with_wildcards(self):
definition = {
"types": ["m.*", "org.matrix.foo.bar"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!foo:bar"
)
self.assertTrue(
self.filtering._passes_definition(definition, event)
)
def test_definition_types_works_with_unknowns(self):
definition = {
"types": ["m.room.message", "org.matrix.foo.bar"]
}
event = MockEvent(
sender="@foo:bar",
type="now.for.something.completely.different",
room_id="!foo:bar"
)
self.assertFalse(
self.filtering._passes_definition(definition, event)
)
def test_definition_not_types_works_with_literals(self):
definition = {
"not_types": ["m.room.message", "org.matrix.foo.bar"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!foo:bar"
)
self.assertFalse(
self.filtering._passes_definition(definition, event)
)
def test_definition_not_types_works_with_wildcards(self):
definition = {
"not_types": ["m.room.message", "org.matrix.*"]
}
event = MockEvent(
sender="@foo:bar",
type="org.matrix.custom.event",
room_id="!foo:bar"
)
self.assertFalse(
self.filtering._passes_definition(definition, event)
)
def test_definition_not_types_works_with_unknowns(self):
definition = {
"not_types": ["m.*", "org.*"]
}
event = MockEvent(
sender="@foo:bar",
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertTrue(
self.filtering._passes_definition(definition, event)
)
def test_definition_not_types_takes_priority_over_types(self):
definition = {
"not_types": ["m.*", "org.*"],
"types": ["m.room.message", "m.room.topic"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.topic",
room_id="!foo:bar"
)
self.assertFalse(
self.filtering._passes_definition(definition, event)
)
def test_definition_senders_works_with_literals(self):
definition = {
"senders": ["@flibble:wibble"]
}
event = MockEvent(
sender="@flibble:wibble",
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertTrue(
self.filtering._passes_definition(definition, event)
)
def test_definition_senders_works_with_unknowns(self):
definition = {
"senders": ["@flibble:wibble"]
}
event = MockEvent(
sender="@challenger:appears",
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertFalse(
self.filtering._passes_definition(definition, event)
)
def test_definition_not_senders_works_with_literals(self):
definition = {
"not_senders": ["@flibble:wibble"]
}
event = MockEvent(
sender="@flibble:wibble",
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertFalse(
self.filtering._passes_definition(definition, event)
)
def test_definition_not_senders_works_with_unknowns(self):
definition = {
"not_senders": ["@flibble:wibble"]
}
event = MockEvent(
sender="@challenger:appears",
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertTrue(
self.filtering._passes_definition(definition, event)
)
def test_definition_not_senders_takes_priority_over_senders(self):
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets", "@misspiggy:muppets"]
}
event = MockEvent(
sender="@misspiggy:muppets",
type="m.room.topic",
room_id="!foo:bar"
)
self.assertFalse(
self.filtering._passes_definition(definition, event)
)
def test_definition_rooms_works_with_literals(self):
definition = {
"rooms": ["!secretbase:unknown"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!secretbase:unknown"
)
self.assertTrue(
self.filtering._passes_definition(definition, event)
)
def test_definition_rooms_works_with_unknowns(self):
definition = {
"rooms": ["!secretbase:unknown"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!anothersecretbase:unknown"
)
self.assertFalse(
self.filtering._passes_definition(definition, event)
)
def test_definition_not_rooms_works_with_literals(self):
definition = {
"not_rooms": ["!anothersecretbase:unknown"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!anothersecretbase:unknown"
)
self.assertFalse(
self.filtering._passes_definition(definition, event)
)
def test_definition_not_rooms_works_with_unknowns(self):
definition = {
"not_rooms": ["!secretbase:unknown"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!anothersecretbase:unknown"
)
self.assertTrue(
self.filtering._passes_definition(definition, event)
)
def test_definition_not_rooms_takes_priority_over_rooms(self):
definition = {
"not_rooms": ["!secretbase:unknown"],
"rooms": ["!secretbase:unknown"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!secretbase:unknown"
)
self.assertFalse(
self.filtering._passes_definition(definition, event)
)
def test_definition_combined_event(self):
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
"rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"]
}
event = MockEvent(
sender="@kermit:muppets", # yup
type="m.room.message", # yup
room_id="!stage:unknown" # yup
)
self.assertTrue(
self.filtering._passes_definition(definition, event)
)
def test_definition_combined_event_bad_sender(self):
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
"rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"]
}
event = MockEvent(
sender="@misspiggy:muppets", # nope
type="m.room.message", # yup
room_id="!stage:unknown" # yup
)
self.assertFalse(
self.filtering._passes_definition(definition, event)
)
def test_definition_combined_event_bad_room(self):
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
"rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"]
}
event = MockEvent(
sender="@kermit:muppets", # yup
type="m.room.message", # yup
room_id="!piggyshouse:muppets" # nope
)
self.assertFalse(
self.filtering._passes_definition(definition, event)
)
def test_definition_combined_event_bad_type(self):
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
"rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"]
}
event = MockEvent(
sender="@kermit:muppets", # yup
type="muppets.misspiggy.kisses", # nope
room_id="!stage:unknown" # yup
)
self.assertFalse(
self.filtering._passes_definition(definition, event)
)
@defer.inlineCallbacks
def test_filter_public_user_data_match(self):
user_filter = {
"public_user_data": {
"types": ["m.*"]
}
}
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
user_filter=user_filter,
)
event = MockEvent(
sender="@foo:bar",
type="m.profile",
room_id="!foo:bar"
)
events = [event]
results = yield self.filtering.filter_public_user_data(
events=events,
user=user,
filter_id=filter_id
)
self.assertEquals(events, results)
@defer.inlineCallbacks
def test_filter_public_user_data_no_match(self):
user_filter = {
"public_user_data": {
"types": ["m.*"]
}
}
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
user_filter=user_filter,
)
event = MockEvent(
sender="@foo:bar",
type="custom.avatar.3d.crazy",
room_id="!foo:bar"
)
events = [event]
results = yield self.filtering.filter_public_user_data(
events=events,
user=user,
filter_id=filter_id
)
self.assertEquals([], results)
@defer.inlineCallbacks
def test_filter_room_state_match(self):
user_filter = {
"room": {
"state": {
"types": ["m.*"]
}
}
}
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
user_filter=user_filter,
)
event = MockEvent(
sender="@foo:bar",
type="m.room.topic",
room_id="!foo:bar"
)
events = [event]
results = yield self.filtering.filter_room_state(
events=events,
user=user,
filter_id=filter_id
)
self.assertEquals(events, results)
@defer.inlineCallbacks
def test_filter_room_state_no_match(self):
user_filter = {
"room": {
"state": {
"types": ["m.*"]
}
}
}
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
user_filter=user_filter,
)
event = MockEvent(
sender="@foo:bar",
type="org.matrix.custom.event",
room_id="!foo:bar"
)
events = [event]
results = yield self.filtering.filter_room_state(
events=events,
user=user,
filter_id=filter_id
)
self.assertEquals([], results)
@defer.inlineCallbacks
def test_add_filter(self):
user_filter = {
"room": {
"state": {
"types": ["m.*"]
}
}
}
filter_id = yield self.filtering.add_user_filter(
user_localpart=user_localpart,
user_filter=user_filter,
)
self.assertEquals(filter_id, 0)
self.assertEquals(user_filter,
(yield self.datastore.get_user_filter(
user_localpart=user_localpart,
filter_id=0,
))
)
@defer.inlineCallbacks
def test_get_filter(self):
user_filter = {
"room": {
"state": {
"types": ["m.*"]
}
}
}
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
user_filter=user_filter,
)
filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart,
filter_id=filter_id,
)
self.assertEquals(filter, user_filter)

View file

@ -39,9 +39,7 @@ class V2AlphaRestTestCase(unittest.TestCase):
hs = HomeServer("test", hs = HomeServer("test",
db_pool=None, db_pool=None,
datastore=Mock(spec=[ datastore=self.make_datastore_mock(),
"insert_client_ip",
]),
http_client=None, http_client=None,
resource_for_client=self.mock_resource, resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource, resource_for_federation=self.mock_resource,
@ -58,3 +56,8 @@ class V2AlphaRestTestCase(unittest.TestCase):
for r in self.TO_REGISTER: for r in self.TO_REGISTER:
r.register_servlets(hs, self.mock_resource) r.register_servlets(hs, self.mock_resource)
def make_datastore_mock(self):
return Mock(spec=[
"insert_client_ip",
])

View file

@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from mock import Mock
from . import V2AlphaRestTestCase
from synapse.rest.client.v2_alpha import filter
from synapse.api.errors import StoreError
class FilterTestCase(V2AlphaRestTestCase):
USER_ID = "@apple:test"
TO_REGISTER = [filter]
def make_datastore_mock(self):
datastore = super(FilterTestCase, self).make_datastore_mock()
self._user_filters = {}
def add_user_filter(user_localpart, definition):
filters = self._user_filters.setdefault(user_localpart, [])
filter_id = len(filters)
filters.append(definition)
return defer.succeed(filter_id)
datastore.add_user_filter = add_user_filter
def get_user_filter(user_localpart, filter_id):
if user_localpart not in self._user_filters:
raise StoreError(404, "No user")
filters = self._user_filters[user_localpart]
if filter_id >= len(filters):
raise StoreError(404, "No filter")
return defer.succeed(filters[filter_id])
datastore.get_user_filter = get_user_filter
return datastore
@defer.inlineCallbacks
def test_add_filter(self):
(code, response) = yield self.mock_resource.trigger("POST",
"/user/%s/filter" % (self.USER_ID),
'{"type": ["m.*"]}'
)
self.assertEquals(200, code)
self.assertEquals({"filter_id": "0"}, response)
self.assertIn("apple", self._user_filters)
self.assertEquals(len(self._user_filters["apple"]), 1)
self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0])
@defer.inlineCallbacks
def test_get_filter(self):
self._user_filters["apple"] = [
{"type": ["m.*"]}
]
(code, response) = yield self.mock_resource.trigger("GET",
"/user/%s/filter/0" % (self.USER_ID), None
)
self.assertEquals(200, code)
self.assertEquals({"type": ["m.*"]}, response)
@defer.inlineCallbacks
def test_get_filter_no_id(self):
self._user_filters["apple"] = [
{"type": ["m.*"]}
]
(code, response) = yield self.mock_resource.trigger("GET",
"/user/%s/filter/2" % (self.USER_ID), None
)
self.assertEquals(404, code)
@defer.inlineCallbacks
def test_get_filter_no_user(self):
(code, response) = yield self.mock_resource.trigger("GET",
"/user/%s/filter/0" % (self.USER_ID), None
)
self.assertEquals(404, code)