
Merge branch 'client_v2_filter' into client_v2_sync

Conflicts:
	synapse/rest/client/v2_alpha/__init__.py
Commit 396a67a09a by Mark Haines, 2015-01-29 14:58:00 +00:00
15 changed files with 1100 additions and 20 deletions

synapse/api/filtering.py (new file, 246 lines)

@@ -0,0 +1,246 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer

from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID


class Filtering(object):

    def __init__(self, hs):
        super(Filtering, self).__init__()
        self.store = hs.get_datastore()

    def get_user_filter(self, user_localpart, filter_id):
        return self.store.get_user_filter(user_localpart, filter_id)

    def add_user_filter(self, user_localpart, user_filter):
        self._check_valid_filter(user_filter)
        return self.store.add_user_filter(user_localpart, user_filter)

    def filter_public_user_data(self, events, user, filter_id):
        return self._filter_on_key(
            events, user, filter_id, ["public_user_data"]
        )

    def filter_private_user_data(self, events, user, filter_id):
        return self._filter_on_key(
            events, user, filter_id, ["private_user_data"]
        )

    def filter_room_state(self, events, user, filter_id):
        return self._filter_on_key(
            events, user, filter_id, ["room", "state"]
        )

    def filter_room_events(self, events, user, filter_id):
        return self._filter_on_key(
            events, user, filter_id, ["room", "events"]
        )

    def filter_room_ephemeral(self, events, user, filter_id):
        return self._filter_on_key(
            events, user, filter_id, ["room", "ephemeral"]
        )

    # TODO(paul): surely we should probably add a delete_user_filter or
    #   replace_user_filter at some point? There's no REST API specified for
    #   them however

    @defer.inlineCallbacks
    def _filter_on_key(self, events, user, filter_id, keys):
        filter_json = yield self.get_user_filter(user.localpart, filter_id)
        if not filter_json:
            defer.returnValue(events)

        try:
            # extract the right definition from the filter
            definition = filter_json
            for key in keys:
                definition = definition[key]
            defer.returnValue(self._filter_with_definition(events, definition))
        except KeyError:
            # return all events if definition isn't specified.
            defer.returnValue(events)

    def _filter_with_definition(self, events, definition):
        return [e for e in events if self._passes_definition(definition, e)]

    def _passes_definition(self, definition, event):
        """Check if the event passes through the given definition.

        Args:
            definition(dict): The definition to check against.
            event(Event): The event to check.
        Returns:
            True if the event passes through the filter.
        """
        # Algorithm notes:
        # For each key in the definition, check the event meets the criteria:
        #   * For types: Literal match or prefix match (if ends with wildcard)
        #   * For senders/rooms: Literal match only
        #   * "not_" checks take precedence (e.g. if "m.*" is in both 'types'
        #     and 'not_types' then it is treated as only being in 'not_types')

        # room checks
        if hasattr(event, "room_id"):
            room_id = event.room_id
            allow_rooms = definition["rooms"] if "rooms" in definition else None
            reject_rooms = (
                definition["not_rooms"] if "not_rooms" in definition else None
            )
            if reject_rooms and room_id in reject_rooms:
                return False
            if allow_rooms and room_id not in allow_rooms:
                return False

        # sender checks
        if hasattr(event, "sender"):
            # Should we be including event.state_key for some event types?
            sender = event.sender
            allow_senders = (
                definition["senders"] if "senders" in definition else None
            )
            reject_senders = (
                definition["not_senders"] if "not_senders" in definition else None
            )
            if reject_senders and sender in reject_senders:
                return False
            if allow_senders and sender not in allow_senders:
                return False

        # type checks
        if "not_types" in definition:
            for def_type in definition["not_types"]:
                if self._event_matches_type(event, def_type):
                    return False
        if "types" in definition:
            included = False
            for def_type in definition["types"]:
                if self._event_matches_type(event, def_type):
                    included = True
                    break
            if not included:
                return False

        return True

    def _event_matches_type(self, event, def_type):
        if def_type.endswith("*"):
            type_prefix = def_type[:-1]
            return event.type.startswith(type_prefix)
        else:
            return event.type == def_type

    def _check_valid_filter(self, user_filter):
        """Check if the provided filter is valid.

        This inspects all definitions contained within the filter.

        Args:
            user_filter(dict): The filter
        Raises:
            SynapseError: If the filter is not valid.
        """
        # NB: Filters are the complete json blobs. "Definitions" are an
        # individual top-level key e.g. public_user_data. Filters are made of
        # many definitions.

        top_level_definitions = [
            "public_user_data", "private_user_data", "server_data"
        ]

        room_level_definitions = [
            "state", "events", "ephemeral"
        ]

        for key in top_level_definitions:
            if key in user_filter:
                self._check_definition(user_filter[key])

        if "room" in user_filter:
            for key in room_level_definitions:
                if key in user_filter["room"]:
                    self._check_definition(user_filter["room"][key])

    def _check_definition(self, definition):
        """Check if the provided definition is valid.

        This inspects not only the types but also the values to make sure they
        make sense.

        Args:
            definition(dict): The filter definition
        Raises:
            SynapseError: If there was a problem with this definition.
        """
        # NB: Filters are the complete json blobs. "Definitions" are an
        # individual top-level key e.g. public_user_data. Filters are made of
        # many definitions.
        if type(definition) != dict:
            raise SynapseError(
                400, "Expected JSON object, not %s" % (definition,)
            )

        # check rooms are valid room IDs
        room_id_keys = ["rooms", "not_rooms"]
        for key in room_id_keys:
            if key in definition:
                if type(definition[key]) != list:
                    raise SynapseError(400, "Expected %s to be a list." % key)
                for room_id in definition[key]:
                    RoomID.from_string(room_id)

        # check senders are valid user IDs
        user_id_keys = ["senders", "not_senders"]
        for key in user_id_keys:
            if key in definition:
                if type(definition[key]) != list:
                    raise SynapseError(400, "Expected %s to be a list." % key)
                for user_id in definition[key]:
                    UserID.from_string(user_id)

        # TODO: We don't limit event type values but we probably should...
        # check types are valid event types
        event_keys = ["types", "not_types"]
        for key in event_keys:
            if key in definition:
                if type(definition[key]) != list:
                    raise SynapseError(400, "Expected %s to be a list." % key)
                for event_type in definition[key]:
                    if not isinstance(event_type, basestring):
                        raise SynapseError(400, "Event type should be a string")

        try:
            event_format = definition["format"]
            if event_format not in ["federation", "events"]:
                raise SynapseError(400, "Invalid format: %s" % (event_format,))
        except KeyError:
            pass  # format is optional

        try:
            event_select_list = definition["select"]
            for select_key in event_select_list:
                if select_key not in ["event_id", "origin_server_ts",
                                      "thread_id", "content", "content.body"]:
                    raise SynapseError(400, "Bad select: %s" % (select_key,))
        except KeyError:
            pass  # select is optional

        if ("bundle_updates" in definition and
                type(definition["bundle_updates"]) != bool):
            raise SynapseError(400, "Bad bundle_updates: expected bool.")


@@ -138,7 +138,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
         except InvalidRuleException as e:
             raise SynapseError(400, e.message)
-        user = yield self.auth.get_user_by_req(request)
+        user, _ = yield self.auth.get_user_by_req(request)
         content = _parse_json(request)
@@ -184,7 +184,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
         except InvalidRuleException as e:
             raise SynapseError(400, e.message)
-        user = yield self.auth.get_user_by_req(request)
+        user, _ = yield self.auth.get_user_by_req(request)
         if 'device' in spec:
             rules = yield self.hs.get_datastore().get_push_rules_for_user_name(
@@ -215,7 +215,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request):
-        user = yield self.auth.get_user_by_req(request)
+        user, _ = yield self.auth.get_user_by_req(request)
         # we build up the full structure and then decide which bits of it
         # to send which means doing unnecessary work sometimes but is


@@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request):
-        user = yield self.auth.get_user_by_req(request)
+        user, _ = yield self.auth.get_user_by_req(request)
         content = _parse_json(request)


@@ -13,9 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from . import (
     sync,
+    filter
 )
 from synapse.http.server import JsonResource
@@ -31,3 +31,4 @@ class ClientV2AlphaRestResource(JsonResource):
     @staticmethod
     def register_servlets(client_resource, hs):
         sync.register_servlets(hs, client_resource)
+        filter.register_servlets(hs, client_resource)


@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer

from synapse.api.errors import AuthError, SynapseError
from synapse.http.servlet import RestServlet
from synapse.types import UserID

from ._base import client_v2_pattern

import json
import logging


logger = logging.getLogger(__name__)


class GetFilterRestServlet(RestServlet):
    PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")

    def __init__(self, hs):
        super(GetFilterRestServlet, self).__init__()
        self.hs = hs
        self.auth = hs.get_auth()
        self.filtering = hs.get_filtering()

    @defer.inlineCallbacks
    def on_GET(self, request, user_id, filter_id):
        target_user = UserID.from_string(user_id)
        auth_user = yield self.auth.get_user_by_req(request)

        if target_user != auth_user:
            raise AuthError(403, "Cannot get filters for other users")

        if not self.hs.is_mine(target_user):
            raise SynapseError(400, "Can only get filters for local users")

        try:
            filter_id = int(filter_id)
        except:
            raise SynapseError(400, "Invalid filter_id")

        try:
            filter = yield self.filtering.get_user_filter(
                user_localpart=target_user.localpart,
                filter_id=filter_id,
            )

            defer.returnValue((200, filter))
        except KeyError:
            raise SynapseError(400, "No such filter")


class CreateFilterRestServlet(RestServlet):
    PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter")

    def __init__(self, hs):
        super(CreateFilterRestServlet, self).__init__()
        self.hs = hs
        self.auth = hs.get_auth()
        self.filtering = hs.get_filtering()

    @defer.inlineCallbacks
    def on_POST(self, request, user_id):
        target_user = UserID.from_string(user_id)
        auth_user = yield self.auth.get_user_by_req(request)

        if target_user != auth_user:
            raise AuthError(403, "Cannot create filters for other users")

        if not self.hs.is_mine(target_user):
            raise SynapseError(400, "Can only create filters for local users")

        try:
            content = json.loads(request.content.read())

            # TODO(paul): check for required keys and invalid keys
        except:
            raise SynapseError(400, "Invalid filter definition")

        filter_id = yield self.filtering.add_user_filter(
            user_localpart=target_user.localpart,
            user_filter=content,
        )

        defer.returnValue((200, {"filter_id": str(filter_id)}))


def register_servlets(hs, http_server):
    GetFilterRestServlet(hs).register(http_server)
    CreateFilterRestServlet(hs).register(http_server)
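
Illustration (not part of the commit): a hypothetical client exchange against the two servlets above. The /_matrix/client/v2_alpha prefix, the access-token handling, the user ID and the filter body are assumptions for the sketch, not taken from the diff.

import json
import requests  # hypothetical client-side sketch only

base = "https://hs.example.com/_matrix/client/v2_alpha"
auth = {"access_token": "ACCESS_TOKEN"}  # placeholder

# Upload a filter for our own user; CreateFilterRestServlet replies with
# the new id rendered as a string.
resp = requests.post(
    "%s/user/@alice:example.com/filter" % base,
    params=auth,
    data=json.dumps({"room": {"state": {"types": ["m.*"]}}}),
)
print(resp.json())  # e.g. {"filter_id": "0"}

# Read it back by id; an unknown id is rejected by GetFilterRestServlet.
resp = requests.get("%s/user/@alice:example.com/filter/0" % base, params=auth)
print(resp.json())  # the filter definition that was uploaded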


@@ -33,6 +33,7 @@ from synapse.api.ratelimiting import Ratelimiter
 from synapse.crypto.keyring import Keyring
 from synapse.push.pusherpool import PusherPool
 from synapse.events.builder import EventBuilderFactory
+from synapse.api.filtering import Filtering

 class BaseHomeServer(object):
@@ -81,6 +82,7 @@ class BaseHomeServer(object):
         'keyring',
         'pusherpool',
         'event_builder_factory',
+        'filtering',
     ]

     def __init__(self, hostname, **kwargs):
@@ -200,5 +202,8 @@ class HomeServer(BaseHomeServer):
             hostname=self.hostname,
         )

+    def build_filtering(self):
+        return Filtering(self)

     def build_pusherpool(self):
         return PusherPool(self)
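
Illustration (not part of the commit): the two hunks above work together because the homeserver generates a cached get_filtering() accessor for names in its dependency list, which calls build_filtering() on first use. A minimal sketch of that lazy-build pattern, using stand-in classes rather than the real BaseHomeServer:

class Filtering(object):  # stand-in for synapse.api.filtering.Filtering
    def __init__(self, hs):
        self.hs = hs


class MiniHomeServer(object):
    DEPENDENCIES = ["filtering"]  # each entry gets a get_<name>() accessor

    def __init__(self):
        self._cached = {}

    def _get(self, name):
        # build_<name>() runs once; later calls reuse the cached instance
        if name not in self._cached:
            self._cached[name] = getattr(self, "build_" + name)()
        return self._cached[name]

    def get_filtering(self):
        return self._get("filtering")

    def build_filtering(self):
        return Filtering(self)


hs = MiniHomeServer()
assert hs.get_filtering() is hs.get_filtering()  # one shared instance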


@@ -32,9 +32,9 @@ from .event_federation import EventFederationStore
 from .pusher import PusherStore
 from .push_rule import PushRuleStore
 from .media_repository import MediaRepositoryStore
 from .state import StateStore
 from .signatures import SignatureStore
+from .filtering import FilteringStore

 from syutil.base64util import decode_base64
 from syutil.jsonutil import encode_canonical_json
@@ -64,6 +64,7 @@ SCHEMAS = [
     "event_signatures",
     "pusher",
     "media_repository",
+    "filtering",
 ]
@@ -85,6 +86,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 DirectoryStore, KeyStore, StateStore, SignatureStore,
                 EventFederationStore,
                 MediaRepositoryStore,
+                FilteringStore,
                 PusherStore,
                 PushRuleStore
                 ):


@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer

from ._base import SQLBaseStore

import json


# TODO(paul)
_filters_for_user = {}


class FilteringStore(SQLBaseStore):
    @defer.inlineCallbacks
    def get_user_filter(self, user_localpart, filter_id):
        def_json = yield self._simple_select_one_onecol(
            table="user_filters",
            keyvalues={
                "user_id": user_localpart,
                "filter_id": filter_id,
            },
            retcol="filter_json",
            allow_none=False,
        )

        defer.returnValue(json.loads(def_json))

    def add_user_filter(self, user_localpart, user_filter):
        def_json = json.dumps(user_filter)

        # Need an atomic transaction to SELECT the maximal ID so far then
        # INSERT a new one
        def _do_txn(txn):
            sql = (
                "SELECT MAX(filter_id) FROM user_filters "
                "WHERE user_id = ?"
            )
            txn.execute(sql, (user_localpart,))
            max_id = txn.fetchone()[0]
            if max_id is None:
                filter_id = 0
            else:
                filter_id = max_id + 1

            sql = (
                "INSERT INTO user_filters (user_id, filter_id, filter_json)"
                "VALUES(?, ?, ?)"
            )
            txn.execute(sql, (user_localpart, filter_id, def_json))

            return filter_id

        return self.runInteraction("add_user_filter", _do_txn)


@@ -175,14 +175,17 @@ class PushRuleStore(SQLBaseStore):
         txn.execute(sql, new_rule.values())

     @defer.inlineCallbacks
-    def delete_push_rule(self, user_name, rule_id):
-        yield self._simple_delete_one(
-            PushRuleTable.table_name,
-            {
-                'user_name': user_name,
-                'rule_id': rule_id
-            }
-        )
+    def delete_push_rule(self, user_name, rule_id, **kwargs):
+        """
+        Delete a push rule. Args specify the row to be deleted and can be
+        any of the columns in the push_rule table, but below are the
+        standard ones
+
+        Args:
+            user_name (str): The matrix ID of the push rule owner
+            rule_id (str): The rule_id of the rule to be deleted
+        """
+        yield self._simple_delete_one(PushRuleTable.table_name, kwargs)


 class RuleNotFoundException(Exception):


@@ -0,0 +1,24 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS user_filters(
    user_id TEXT,
    filter_id INTEGER,
    filter_json TEXT,
    FOREIGN KEY(user_id) REFERENCES users(id)
);

CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
    user_id, filter_id
);


@@ -0,0 +1,24 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS user_filters(
    user_id TEXT,
    filter_id INTEGER,
    filter_json TEXT,
    FOREIGN KEY(user_id) REFERENCES users(id)
);

CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
    user_id, filter_id
);


@@ -82,10 +82,10 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
     def parse(cls, string):
         try:
             if string[0] == 's':
-                return cls(None, int(string[1:]))
+                return cls(topological=None, stream=int(string[1:]))
             if string[0] == 't':
                 parts = string[1:].split('-', 1)
-                return cls(int(parts[1]), int(parts[0]))
+                return cls(topological=int(parts[0]), stream=int(parts[1]))
         except:
             pass
         raise SynapseError(400, "Invalid token %r" % (string,))
@@ -94,7 +94,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
     def parse_stream_token(cls, string):
         try:
             if string[0] == 's':
-                return cls(None, int(string[1:]))
+                return cls(topological=None, stream=int(string[1:]))
         except:
             pass
         raise SynapseError(400, "Invalid token %r" % (string,))

tests/api/test_filtering.py (new file, 506 lines)

@@ -0,0 +1,506 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple

from tests import unittest
from twisted.internet import defer

from mock import Mock, NonCallableMock
from tests.utils import (
    MockHttpResource, MockClock, DeferredMockCallable, SQLiteMemoryDbPool,
    MockKey
)

from synapse.server import HomeServer
from synapse.types import UserID

user_localpart = "test_user"
MockEvent = namedtuple("MockEvent", "sender type room_id")


class FilteringTestCase(unittest.TestCase):

    @defer.inlineCallbacks
    def setUp(self):
        db_pool = SQLiteMemoryDbPool()
        yield db_pool.prepare()

        self.mock_config = NonCallableMock()
        self.mock_config.signing_key = [MockKey()]

        self.mock_federation_resource = MockHttpResource()

        self.mock_http_client = Mock(spec=[])
        self.mock_http_client.put_json = DeferredMockCallable()

        hs = HomeServer("test",
            db_pool=db_pool,
            handlers=None,
            http_client=self.mock_http_client,
            config=self.mock_config,
            keyring=Mock(),
        )

        self.filtering = hs.get_filtering()
        self.datastore = hs.get_datastore()

    def test_definition_types_works_with_literals(self):
        definition = {
            "types": ["m.room.message", "org.matrix.foo.bar"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!foo:bar"
        )

        self.assertTrue(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_types_works_with_wildcards(self):
        definition = {
            "types": ["m.*", "org.matrix.foo.bar"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!foo:bar"
        )
        self.assertTrue(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_types_works_with_unknowns(self):
        definition = {
            "types": ["m.room.message", "org.matrix.foo.bar"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="now.for.something.completely.different",
            room_id="!foo:bar"
        )
        self.assertFalse(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_not_types_works_with_literals(self):
        definition = {
            "not_types": ["m.room.message", "org.matrix.foo.bar"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!foo:bar"
        )
        self.assertFalse(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_not_types_works_with_wildcards(self):
        definition = {
            "not_types": ["m.room.message", "org.matrix.*"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="org.matrix.custom.event",
            room_id="!foo:bar"
        )
        self.assertFalse(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_not_types_works_with_unknowns(self):
        definition = {
            "not_types": ["m.*", "org.*"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="com.nom.nom.nom",
            room_id="!foo:bar"
        )
        self.assertTrue(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_not_types_takes_priority_over_types(self):
        definition = {
            "not_types": ["m.*", "org.*"],
            "types": ["m.room.message", "m.room.topic"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.topic",
            room_id="!foo:bar"
        )
        self.assertFalse(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_senders_works_with_literals(self):
        definition = {
            "senders": ["@flibble:wibble"]
        }
        event = MockEvent(
            sender="@flibble:wibble",
            type="com.nom.nom.nom",
            room_id="!foo:bar"
        )
        self.assertTrue(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_senders_works_with_unknowns(self):
        definition = {
            "senders": ["@flibble:wibble"]
        }
        event = MockEvent(
            sender="@challenger:appears",
            type="com.nom.nom.nom",
            room_id="!foo:bar"
        )
        self.assertFalse(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_not_senders_works_with_literals(self):
        definition = {
            "not_senders": ["@flibble:wibble"]
        }
        event = MockEvent(
            sender="@flibble:wibble",
            type="com.nom.nom.nom",
            room_id="!foo:bar"
        )
        self.assertFalse(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_not_senders_works_with_unknowns(self):
        definition = {
            "not_senders": ["@flibble:wibble"]
        }
        event = MockEvent(
            sender="@challenger:appears",
            type="com.nom.nom.nom",
            room_id="!foo:bar"
        )
        self.assertTrue(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_not_senders_takes_priority_over_senders(self):
        definition = {
            "not_senders": ["@misspiggy:muppets"],
            "senders": ["@kermit:muppets", "@misspiggy:muppets"]
        }
        event = MockEvent(
            sender="@misspiggy:muppets",
            type="m.room.topic",
            room_id="!foo:bar"
        )
        self.assertFalse(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_rooms_works_with_literals(self):
        definition = {
            "rooms": ["!secretbase:unknown"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!secretbase:unknown"
        )
        self.assertTrue(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_rooms_works_with_unknowns(self):
        definition = {
            "rooms": ["!secretbase:unknown"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!anothersecretbase:unknown"
        )
        self.assertFalse(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_not_rooms_works_with_literals(self):
        definition = {
            "not_rooms": ["!anothersecretbase:unknown"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!anothersecretbase:unknown"
        )
        self.assertFalse(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_not_rooms_works_with_unknowns(self):
        definition = {
            "not_rooms": ["!secretbase:unknown"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!anothersecretbase:unknown"
        )
        self.assertTrue(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_not_rooms_takes_priority_over_rooms(self):
        definition = {
            "not_rooms": ["!secretbase:unknown"],
            "rooms": ["!secretbase:unknown"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!secretbase:unknown"
        )
        self.assertFalse(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_combined_event(self):
        definition = {
            "not_senders": ["@misspiggy:muppets"],
            "senders": ["@kermit:muppets"],
            "rooms": ["!stage:unknown"],
            "not_rooms": ["!piggyshouse:muppets"],
            "types": ["m.room.message", "muppets.kermit.*"],
            "not_types": ["muppets.misspiggy.*"]
        }
        event = MockEvent(
            sender="@kermit:muppets",  # yup
            type="m.room.message",  # yup
            room_id="!stage:unknown"  # yup
        )
        self.assertTrue(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_combined_event_bad_sender(self):
        definition = {
            "not_senders": ["@misspiggy:muppets"],
            "senders": ["@kermit:muppets"],
            "rooms": ["!stage:unknown"],
            "not_rooms": ["!piggyshouse:muppets"],
            "types": ["m.room.message", "muppets.kermit.*"],
            "not_types": ["muppets.misspiggy.*"]
        }
        event = MockEvent(
            sender="@misspiggy:muppets",  # nope
            type="m.room.message",  # yup
            room_id="!stage:unknown"  # yup
        )
        self.assertFalse(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_combined_event_bad_room(self):
        definition = {
            "not_senders": ["@misspiggy:muppets"],
            "senders": ["@kermit:muppets"],
            "rooms": ["!stage:unknown"],
            "not_rooms": ["!piggyshouse:muppets"],
            "types": ["m.room.message", "muppets.kermit.*"],
            "not_types": ["muppets.misspiggy.*"]
        }
        event = MockEvent(
            sender="@kermit:muppets",  # yup
            type="m.room.message",  # yup
            room_id="!piggyshouse:muppets"  # nope
        )
        self.assertFalse(
            self.filtering._passes_definition(definition, event)
        )

    def test_definition_combined_event_bad_type(self):
        definition = {
            "not_senders": ["@misspiggy:muppets"],
            "senders": ["@kermit:muppets"],
            "rooms": ["!stage:unknown"],
            "not_rooms": ["!piggyshouse:muppets"],
            "types": ["m.room.message", "muppets.kermit.*"],
            "not_types": ["muppets.misspiggy.*"]
        }
        event = MockEvent(
            sender="@kermit:muppets",  # yup
            type="muppets.misspiggy.kisses",  # nope
            room_id="!stage:unknown"  # yup
        )
        self.assertFalse(
            self.filtering._passes_definition(definition, event)
        )

    @defer.inlineCallbacks
    def test_filter_public_user_data_match(self):
        user_filter = {
            "public_user_data": {
                "types": ["m.*"]
            }
        }

        user = UserID.from_string("@" + user_localpart + ":test")
        filter_id = yield self.datastore.add_user_filter(
            user_localpart=user_localpart,
            user_filter=user_filter,
        )
        event = MockEvent(
            sender="@foo:bar",
            type="m.profile",
            room_id="!foo:bar"
        )
        events = [event]

        results = yield self.filtering.filter_public_user_data(
            events=events,
            user=user,
            filter_id=filter_id
        )
        self.assertEquals(events, results)

    @defer.inlineCallbacks
    def test_filter_public_user_data_no_match(self):
        user_filter = {
            "public_user_data": {
                "types": ["m.*"]
            }
        }

        user = UserID.from_string("@" + user_localpart + ":test")
        filter_id = yield self.datastore.add_user_filter(
            user_localpart=user_localpart,
            user_filter=user_filter,
        )
        event = MockEvent(
            sender="@foo:bar",
            type="custom.avatar.3d.crazy",
            room_id="!foo:bar"
        )
        events = [event]

        results = yield self.filtering.filter_public_user_data(
            events=events,
            user=user,
            filter_id=filter_id
        )
        self.assertEquals([], results)

    @defer.inlineCallbacks
    def test_filter_room_state_match(self):
        user_filter = {
            "room": {
                "state": {
                    "types": ["m.*"]
                }
            }
        }

        user = UserID.from_string("@" + user_localpart + ":test")
        filter_id = yield self.datastore.add_user_filter(
            user_localpart=user_localpart,
            user_filter=user_filter,
        )
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.topic",
            room_id="!foo:bar"
        )
        events = [event]

        results = yield self.filtering.filter_room_state(
            events=events,
            user=user,
            filter_id=filter_id
        )
        self.assertEquals(events, results)

    @defer.inlineCallbacks
    def test_filter_room_state_no_match(self):
        user_filter = {
            "room": {
                "state": {
                    "types": ["m.*"]
                }
            }
        }

        user = UserID.from_string("@" + user_localpart + ":test")
        filter_id = yield self.datastore.add_user_filter(
            user_localpart=user_localpart,
            user_filter=user_filter,
        )
        event = MockEvent(
            sender="@foo:bar",
            type="org.matrix.custom.event",
            room_id="!foo:bar"
        )
        events = [event]

        results = yield self.filtering.filter_room_state(
            events=events,
            user=user,
            filter_id=filter_id
        )
        self.assertEquals([], results)

    @defer.inlineCallbacks
    def test_add_filter(self):
        user_filter = {
            "room": {
                "state": {
                    "types": ["m.*"]
                }
            }
        }

        filter_id = yield self.filtering.add_user_filter(
            user_localpart=user_localpart,
            user_filter=user_filter,
        )

        self.assertEquals(filter_id, 0)
        self.assertEquals(user_filter,
            (yield self.datastore.get_user_filter(
                user_localpart=user_localpart,
                filter_id=0,
            ))
        )

    @defer.inlineCallbacks
    def test_get_filter(self):
        user_filter = {
            "room": {
                "state": {
                    "types": ["m.*"]
                }
            }
        }

        filter_id = yield self.datastore.add_user_filter(
            user_localpart=user_localpart,
            user_filter=user_filter,
        )

        filter = yield self.filtering.get_user_filter(
            user_localpart=user_localpart,
            filter_id=filter_id,
        )

        self.assertEquals(filter, user_filter)


@@ -39,9 +39,7 @@ class V2AlphaRestTestCase(unittest.TestCase):
         hs = HomeServer("test",
             db_pool=None,
-            datastore=Mock(spec=[
-                "insert_client_ip",
-            ]),
+            datastore=self.make_datastore_mock(),
             http_client=None,
             resource_for_client=self.mock_resource,
             resource_for_federation=self.mock_resource,
@@ -58,3 +56,8 @@ class V2AlphaRestTestCase(unittest.TestCase):
         for r in self.TO_REGISTER:
             r.register_servlets(hs, self.mock_resource)

+    def make_datastore_mock(self):
+        return Mock(spec=[
+            "insert_client_ip",
+        ])


@@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer

from mock import Mock

from . import V2AlphaRestTestCase

from synapse.rest.client.v2_alpha import filter

from synapse.api.errors import StoreError


class FilterTestCase(V2AlphaRestTestCase):
    USER_ID = "@apple:test"
    TO_REGISTER = [filter]

    def make_datastore_mock(self):
        datastore = super(FilterTestCase, self).make_datastore_mock()

        self._user_filters = {}

        def add_user_filter(user_localpart, definition):
            filters = self._user_filters.setdefault(user_localpart, [])
            filter_id = len(filters)
            filters.append(definition)
            return defer.succeed(filter_id)
        datastore.add_user_filter = add_user_filter

        def get_user_filter(user_localpart, filter_id):
            if user_localpart not in self._user_filters:
                raise StoreError(404, "No user")
            filters = self._user_filters[user_localpart]
            if filter_id >= len(filters):
                raise StoreError(404, "No filter")
            return defer.succeed(filters[filter_id])
        datastore.get_user_filter = get_user_filter

        return datastore

    @defer.inlineCallbacks
    def test_add_filter(self):
        (code, response) = yield self.mock_resource.trigger("POST",
            "/user/%s/filter" % (self.USER_ID),
            '{"type": ["m.*"]}'
        )
        self.assertEquals(200, code)
        self.assertEquals({"filter_id": "0"}, response)

        self.assertIn("apple", self._user_filters)
        self.assertEquals(len(self._user_filters["apple"]), 1)
        self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0])

    @defer.inlineCallbacks
    def test_get_filter(self):
        self._user_filters["apple"] = [
            {"type": ["m.*"]}
        ]

        (code, response) = yield self.mock_resource.trigger("GET",
            "/user/%s/filter/0" % (self.USER_ID), None
        )
        self.assertEquals(200, code)
        self.assertEquals({"type": ["m.*"]}, response)

    @defer.inlineCallbacks
    def test_get_filter_no_id(self):
        self._user_filters["apple"] = [
            {"type": ["m.*"]}
        ]

        (code, response) = yield self.mock_resource.trigger("GET",
            "/user/%s/filter/2" % (self.USER_ID), None
        )
        self.assertEquals(404, code)

    @defer.inlineCallbacks
    def test_get_filter_no_user(self):
        (code, response) = yield self.mock_resource.trigger("GET",
            "/user/%s/filter/0" % (self.USER_ID), None
        )
        self.assertEquals(404, code)