
Merge branch 'develop' of github.com:matrix-org/synapse into rejections_storage

Conflicts:
	synapse/storage/__init__.py
Erik Johnston 2015-01-30 14:57:33 +00:00
commit fdd2ac495a
14 changed files with 1140 additions and 11 deletions

synapse/api/filtering.py Normal file

@@ -0,0 +1,229 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID


class Filtering(object):

    def __init__(self, hs):
        super(Filtering, self).__init__()
        self.store = hs.get_datastore()

    def get_user_filter(self, user_localpart, filter_id):
        result = self.store.get_user_filter(user_localpart, filter_id)
        result.addCallback(Filter)
        return result

    def add_user_filter(self, user_localpart, user_filter):
        self._check_valid_filter(user_filter)
        return self.store.add_user_filter(user_localpart, user_filter)

    # TODO(paul): surely we should probably add a delete_user_filter or
    #   replace_user_filter at some point? There's no REST API specified for
    #   them however

    def _check_valid_filter(self, user_filter_json):
        """Check if the provided filter is valid.

        This inspects all definitions contained within the filter.

        Args:
            user_filter_json(dict): The filter
        Raises:
            SynapseError: If the filter is not valid.
        """
        # NB: Filters are the complete json blobs. "Definitions" are an
        # individual top-level key e.g. public_user_data. Filters are made of
        # many definitions.

        top_level_definitions = [
            "public_user_data", "private_user_data", "server_data"
        ]

        room_level_definitions = [
            "state", "events", "ephemeral"
        ]

        for key in top_level_definitions:
            if key in user_filter_json:
                self._check_definition(user_filter_json[key])

        if "room" in user_filter_json:
            for key in room_level_definitions:
                if key in user_filter_json["room"]:
                    self._check_definition(user_filter_json["room"][key])

    def _check_definition(self, definition):
        """Check if the provided definition is valid.

        This inspects not only the types but also the values to make sure they
        make sense.

        Args:
            definition(dict): The filter definition
        Raises:
            SynapseError: If there was a problem with this definition.
        """
        # NB: Filters are the complete json blobs. "Definitions" are an
        # individual top-level key e.g. public_user_data. Filters are made of
        # many definitions.
        if type(definition) != dict:
            raise SynapseError(
                400, "Expected JSON object, not %s" % (definition,)
            )

        # check rooms are valid room IDs
        room_id_keys = ["rooms", "not_rooms"]
        for key in room_id_keys:
            if key in definition:
                if type(definition[key]) != list:
                    raise SynapseError(400, "Expected %s to be a list." % key)
                for room_id in definition[key]:
                    RoomID.from_string(room_id)

        # check senders are valid user IDs
        user_id_keys = ["senders", "not_senders"]
        for key in user_id_keys:
            if key in definition:
                if type(definition[key]) != list:
                    raise SynapseError(400, "Expected %s to be a list." % key)
                for user_id in definition[key]:
                    UserID.from_string(user_id)

        # TODO: We don't limit event type values but we probably should...
        # check types are valid event types
        event_keys = ["types", "not_types"]
        for key in event_keys:
            if key in definition:
                if type(definition[key]) != list:
                    raise SynapseError(400, "Expected %s to be a list." % key)
                for event_type in definition[key]:
                    if not isinstance(event_type, basestring):
                        raise SynapseError(400, "Event type should be a string")

        if "format" in definition:
            event_format = definition["format"]
            if event_format not in ["federation", "events"]:
                raise SynapseError(400, "Invalid format: %s" % (event_format,))

        if "select" in definition:
            event_select_list = definition["select"]
            for select_key in event_select_list:
                if select_key not in ["event_id", "origin_server_ts",
                                      "thread_id", "content", "content.body"]:
                    raise SynapseError(400, "Bad select: %s" % (select_key,))

        if ("bundle_updates" in definition and
                type(definition["bundle_updates"]) != bool):
            raise SynapseError(400, "Bad bundle_updates: expected bool.")


class Filter(object):
    def __init__(self, filter_json):
        self.filter_json = filter_json

    def filter_public_user_data(self, events):
        return self._filter_on_key(events, ["public_user_data"])

    def filter_private_user_data(self, events):
        return self._filter_on_key(events, ["private_user_data"])

    def filter_room_state(self, events):
        return self._filter_on_key(events, ["room", "state"])

    def filter_room_events(self, events):
        return self._filter_on_key(events, ["room", "events"])

    def filter_room_ephemeral(self, events):
        return self._filter_on_key(events, ["room", "ephemeral"])

    def _filter_on_key(self, events, keys):
        filter_json = self.filter_json
        if not filter_json:
            return events

        try:
            # extract the right definition from the filter
            definition = filter_json
            for key in keys:
                definition = definition[key]
            return self._filter_with_definition(events, definition)
        except KeyError:
            # return all events if definition isn't specified.
            return events

    def _filter_with_definition(self, events, definition):
        return [e for e in events if self._passes_definition(definition, e)]

    def _passes_definition(self, definition, event):
        """Check if the event passes through the given definition.

        Args:
            definition(dict): The definition to check against.
            event(Event): The event to check.
        Returns:
            True if the event passes through the filter.
        """
        # Algorithm notes:
        # For each key in the definition, check the event meets the criteria:
        #   * For types: Literal match or prefix match (if ends with wildcard)
        #   * For senders/rooms: Literal match only
        #   * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
        #     and 'not_types' then it is treated as only being in 'not_types')

        # room checks
        if hasattr(event, "room_id"):
            room_id = event.room_id
            allow_rooms = definition.get("rooms", None)
            reject_rooms = definition.get("not_rooms", None)
            if reject_rooms and room_id in reject_rooms:
                return False
            if allow_rooms and room_id not in allow_rooms:
                return False

        # sender checks
        if hasattr(event, "sender"):
            # Should we be including event.state_key for some event types?
            sender = event.sender
            allow_senders = definition.get("senders", None)
            reject_senders = definition.get("not_senders", None)
            if reject_senders and sender in reject_senders:
                return False
            if allow_senders and sender not in allow_senders:
                return False

        # type checks
        if "not_types" in definition:
            for def_type in definition["not_types"]:
                if self._event_matches_type(event, def_type):
                    return False
        if "types" in definition:
            included = False
            for def_type in definition["types"]:
                if self._event_matches_type(event, def_type):
                    included = True
                    break
            if not included:
                return False

        return True

    def _event_matches_type(self, event, def_type):
        if def_type.endswith("*"):
            type_prefix = def_type[:-1]
            return event.type.startswith(type_prefix)
        else:
            return event.type == def_type
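
Taken together, the Filter class applies one of these definitions to a list of events. A minimal usage sketch (not part of the commit) follows; it reuses the MockEvent namedtuple that the tests further down introduce and assumes the module above is importable as synapse.api.filtering:

from collections import namedtuple

from synapse.api.filtering import Filter

MockEvent = namedtuple("MockEvent", "sender type room_id")

user_filter = Filter({
    "room": {
        "events": {
            "types": ["m.room.message", "m.room.topic"],
            "not_senders": ["@spammer:example.org"],
        }
    }
})

events = [
    MockEvent("@alice:example.org", "m.room.message", "!room:example.org"),
    MockEvent("@spammer:example.org", "m.room.message", "!room:example.org"),
    MockEvent("@alice:example.org", "org.example.custom", "!room:example.org"),
]

# Only the first event survives: the second is dropped by "not_senders" and
# the third by the "types" whitelist.
print(user_filter.filter_room_events(events))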


@@ -24,6 +24,7 @@ import baserules
 import logging
 import fnmatch
 import json
+import re

 logger = logging.getLogger(__name__)

@@ -34,6 +35,8 @@ class Pusher(object):
     GIVE_UP_AFTER = 24 * 60 * 60 * 1000
     DEFAULT_ACTIONS = ['notify']

+    INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
+
     def __init__(self, _hs, instance_handle, user_name, app_id,
                  app_display_name, device_display_name, pushkey, pushkey_ts,
                  data, last_token, last_success, failing_since):

@@ -88,11 +91,21 @@ class Pusher(object):
         member_events_for_room = yield self.store.get_current_state(
             room_id=ev['room_id'],
             event_type='m.room.member',
-            state_key=self.user_name
+            state_key=None
         )
         my_display_name = None
-        if len(member_events_for_room) > 0:
-            my_display_name = member_events_for_room[0].content['displayname']
+        room_member_count = 0
+        for mev in member_events_for_room:
+            if mev.content['membership'] != 'join':
+                continue
+            # This loop does two things:
+            # 1) Find our current display name
+            if mev.state_key == self.user_name:
+                my_display_name = mev.content['displayname']
+            # and 2) Get the number of people in that room
+            room_member_count += 1

         for r in rules:

@@ -102,7 +115,8 @@ class Pusher(object):
             for c in conditions:
                 matches &= self._event_fulfills_condition(
-                    ev, c, display_name=my_display_name
+                    ev, c, display_name=my_display_name,
+                    room_member_count=room_member_count
                 )
             # ignore rules with no actions (we have an explict 'dont_notify'
             if len(actions) == 0:

@@ -116,7 +130,7 @@ class Pusher(object):
         defer.returnValue(Pusher.DEFAULT_ACTIONS)

-    def _event_fulfills_condition(self, ev, condition, display_name):
+    def _event_fulfills_condition(self, ev, condition, display_name, room_member_count):
         if condition['kind'] == 'event_match':
             if 'pattern' not in condition:
                 logger.warn("event_match condition with no pattern")

@@ -138,9 +152,35 @@ class Pusher(object):
             # the event stream.
             if 'content' not in ev or 'body' not in ev['content']:
                 return False
+            if not display_name:
+                return False
             return fnmatch.fnmatch(
                 ev['content']['body'].upper(), "*%s*" % (display_name.upper(),)
             )
+        elif condition['kind'] == 'room_member_count':
+            if 'is' not in condition:
+                return False
+            m = Pusher.INEQUALITY_EXPR.match(condition['is'])
+            if not m:
+                return False
+            ineq = m.group(1)
+            rhs = m.group(2)
+            if not rhs.isdigit():
+                return False
+            rhs = int(rhs)
+
+            if ineq == '' or ineq == '==':
+                return room_member_count == rhs
+            elif ineq == '<':
+                return room_member_count < rhs
+            elif ineq == '>':
+                return room_member_count > rhs
+            elif ineq == '>=':
+                return room_member_count >= rhs
+            elif ineq == '<=':
+                return room_member_count <= rhs
+            else:
+                return False
         else:
             return True
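
The new room_member_count condition compares the number of joined members against the condition's 'is' value, which may carry an optional comparator ("2", "==2", "<=10", ">5", and so on). A standalone sketch of that parse-and-compare step, where member_count_matches is a hypothetical helper and not part of the commit:

import re

# Same pattern as Pusher.INEQUALITY_EXPR above.
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")

def member_count_matches(condition_is, room_member_count):
    # Split the 'is' value into an optional comparator and an integer,
    # then apply the comparison, mirroring the elif branch above.
    m = INEQUALITY_EXPR.match(condition_is)
    if not m:
        return False
    ineq, rhs = m.group(1), m.group(2)
    if not rhs.isdigit():
        return False
    rhs = int(rhs)
    if ineq in ('', '=='):
        return room_member_count == rhs
    if ineq == '<':
        return room_member_count < rhs
    if ineq == '>':
        return room_member_count > rhs
    if ineq == '>=':
        return room_member_count >= rhs
    if ineq == '<=':
        return room_member_count <= rhs
    return False

print(member_count_matches('2', 2))     # True: the new one-to-one base rule fires
print(member_count_matches('<=2', 3))   # False
print(member_count_matches('>10', 42))  # True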


@@ -32,4 +32,18 @@ def make_base_rules(user_name):
                 }
             ]
         },
+        {
+            'conditions': [
+                {
+                    'kind': 'room_member_count',
+                    'is': '2'
+                }
+            ],
+            'actions': [
+                'notify',
+                {
+                    'set_sound': 'default'
+                }
+            ]
+        }
     ]


@@ -14,6 +14,11 @@
 # limitations under the License.

+from . import (
+    filter
+)
+
 from synapse.http.server import JsonResource

@@ -26,4 +31,4 @@ class ClientV2AlphaRestResource(JsonResource):
     @staticmethod
     def register_servlets(client_resource, hs):
-        pass
+        filter.register_servlets(hs, client_resource)


@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer

from synapse.api.errors import AuthError, SynapseError
from synapse.http.servlet import RestServlet
from synapse.types import UserID

from ._base import client_v2_pattern

import json
import logging


logger = logging.getLogger(__name__)


class GetFilterRestServlet(RestServlet):
    PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")

    def __init__(self, hs):
        super(GetFilterRestServlet, self).__init__()
        self.hs = hs
        self.auth = hs.get_auth()
        self.filtering = hs.get_filtering()

    @defer.inlineCallbacks
    def on_GET(self, request, user_id, filter_id):
        target_user = UserID.from_string(user_id)
        auth_user, client = yield self.auth.get_user_by_req(request)

        if target_user != auth_user:
            raise AuthError(403, "Cannot get filters for other users")

        if not self.hs.is_mine(target_user):
            raise SynapseError(400, "Can only get filters for local users")

        try:
            filter_id = int(filter_id)
        except:
            raise SynapseError(400, "Invalid filter_id")

        try:
            filter = yield self.filtering.get_user_filter(
                user_localpart=target_user.localpart,
                filter_id=filter_id,
            )

            defer.returnValue((200, filter.filter_json))
        except KeyError:
            raise SynapseError(400, "No such filter")


class CreateFilterRestServlet(RestServlet):
    PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter")

    def __init__(self, hs):
        super(CreateFilterRestServlet, self).__init__()
        self.hs = hs
        self.auth = hs.get_auth()
        self.filtering = hs.get_filtering()

    @defer.inlineCallbacks
    def on_POST(self, request, user_id):
        target_user = UserID.from_string(user_id)
        auth_user, client = yield self.auth.get_user_by_req(request)

        if target_user != auth_user:
            raise AuthError(403, "Cannot create filters for other users")

        if not self.hs.is_mine(target_user):
            raise SynapseError(400, "Can only create filters for local users")

        try:
            content = json.loads(request.content.read())

            # TODO(paul): check for required keys and invalid keys
        except:
            raise SynapseError(400, "Invalid filter definition")

        filter_id = yield self.filtering.add_user_filter(
            user_localpart=target_user.localpart,
            user_filter=content,
        )

        defer.returnValue((200, {"filter_id": str(filter_id)}))


def register_servlets(hs, http_server):
    GetFilterRestServlet(hs).register(http_server)
    CreateFilterRestServlet(hs).register(http_server)
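
For reference, a hypothetical client-side sketch of the two new endpoints (not part of the commit). It assumes client_v2_pattern mounts them under the usual /_matrix/client/v2_alpha prefix, uses the third-party requests library, a placeholder access_token, and omits URL-encoding of the user ID for brevity:

import json
import requests

base = "http://localhost:8008/_matrix/client/v2_alpha"
user_id = "@alice:example.org"
params = {"access_token": "MDAx..."}  # placeholder token

# Upload a filter definition; the response carries the new filter_id.
resp = requests.post(
    "%s/user/%s/filter" % (base, user_id),
    params=params,
    data=json.dumps({"room": {"state": {"types": ["m.*"]}}}),
)
filter_id = resp.json()["filter_id"]   # e.g. "0"

# Fetch it back; the body is the stored filter JSON.
resp = requests.get(
    "%s/user/%s/filter/%s" % (base, user_id, filter_id),
    params=params,
)
print(resp.json())   # {"room": {"state": {"types": ["m.*"]}}}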


@@ -33,6 +33,7 @@ from synapse.api.ratelimiting import Ratelimiter
 from synapse.crypto.keyring import Keyring
 from synapse.push.pusherpool import PusherPool
 from synapse.events.builder import EventBuilderFactory
+from synapse.api.filtering import Filtering


 class BaseHomeServer(object):

@@ -81,6 +82,7 @@ class BaseHomeServer(object):
         'keyring',
         'pusherpool',
         'event_builder_factory',
+        'filtering',
     ]

     def __init__(self, hostname, **kwargs):

@@ -200,5 +202,8 @@ class HomeServer(BaseHomeServer):
             hostname=self.hostname,
         )

+    def build_filtering(self):
+        return Filtering(self)
+
     def build_pusherpool(self):
         return PusherPool(self)


@@ -296,7 +296,8 @@ class StateHandler(object):
             except AuthError:
                 pass

-        # Oh dear.
+        # Use the last event (the one with the least depth) if they all fail
+        # the auth check.
         return event

     def _ordered_events(self, events):


@@ -32,10 +32,14 @@ from .event_federation import EventFederationStore
 from .pusher import PusherStore
 from .push_rule import PushRuleStore
 from .media_repository import MediaRepositoryStore
+<<<<<<< HEAD
 from .rejections import RejectionsStore
+=======
+>>>>>>> 471c47441d0c188e845b75c8f446c44899fdcfe7

 from .state import StateStore
 from .signatures import SignatureStore
+from .filtering import FilteringStore

 from syutil.base64util import decode_base64
 from syutil.jsonutil import encode_canonical_json

@@ -65,6 +69,7 @@ SCHEMAS = [
     "event_signatures",
     "pusher",
     "media_repository",
+    "filtering",
 ]

@@ -87,6 +92,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 EventFederationStore,
                 MediaRepositoryStore,
                 RejectionsStore,
+                FilteringStore,
                 PusherStore,
                 PushRuleStore
                 ):

@@ -380,9 +386,12 @@ class DataStore(RoomMemberStore, RoomStore,
             "redacted": del_sql,
         }

-        if event_type:
+        if event_type and state_key is not None:
             sql += " AND s.type = ? AND s.state_key = ? "
             args = (room_id, event_type, state_key)
+        elif event_type:
+            sql += " AND s.type = ?"
+            args = (room_id, event_type)
         else:
             args = (room_id, )


@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer

from ._base import SQLBaseStore

import json


class FilteringStore(SQLBaseStore):
    @defer.inlineCallbacks
    def get_user_filter(self, user_localpart, filter_id):
        def_json = yield self._simple_select_one_onecol(
            table="user_filters",
            keyvalues={
                "user_id": user_localpart,
                "filter_id": filter_id,
            },
            retcol="filter_json",
            allow_none=False,
        )

        defer.returnValue(json.loads(def_json))

    def add_user_filter(self, user_localpart, user_filter):
        def_json = json.dumps(user_filter)

        # Need an atomic transaction to SELECT the maximal ID so far then
        # INSERT a new one
        def _do_txn(txn):
            sql = (
                "SELECT MAX(filter_id) FROM user_filters "
                "WHERE user_id = ?"
            )
            txn.execute(sql, (user_localpart,))
            max_id = txn.fetchone()[0]
            if max_id is None:
                filter_id = 0
            else:
                filter_id = max_id + 1

            sql = (
                "INSERT INTO user_filters (user_id, filter_id, filter_json)"
                "VALUES(?, ?, ?)"
            )
            txn.execute(sql, (user_localpart, filter_id, def_json))

            return filter_id

        return self.runInteraction("add_user_filter", _do_txn)
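
The SELECT-then-INSERT above allocates per-user filter IDs sequentially from 0. A self-contained sqlite3 sketch of that allocation logic (not part of the commit; it uses an in-memory database instead of Synapse's connection pool):

import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE user_filters("
    " user_id TEXT, filter_id INTEGER, filter_json TEXT)"
)

def add_user_filter(user_localpart, user_filter):
    with conn:  # one transaction covering both the SELECT and the INSERT
        row = conn.execute(
            "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?",
            (user_localpart,),
        ).fetchone()
        filter_id = 0 if row[0] is None else row[0] + 1
        conn.execute(
            "INSERT INTO user_filters (user_id, filter_id, filter_json)"
            " VALUES(?, ?, ?)",
            (user_localpart, filter_id, json.dumps(user_filter)),
        )
    return filter_id

print(add_user_filter("alice", {"public_user_data": {"types": ["m.*"]}}))  # 0
print(add_user_filter("alice", {"room": {"state": {"types": ["m.*"]}}}))   # 1
print(add_user_filter("bob", {}))                                          # 0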


@@ -0,0 +1,24 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS user_filters(
    user_id TEXT,
    filter_id INTEGER,
    filter_json TEXT,
    FOREIGN KEY(user_id) REFERENCES users(id)
);

CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
    user_id, filter_id
);


@@ -0,0 +1,24 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS user_filters(
    user_id TEXT,
    filter_id INTEGER,
    filter_json TEXT,
    FOREIGN KEY(user_id) REFERENCES users(id)
);

CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
    user_id, filter_id
);

tests/api/test_filtering.py Normal file

@@ -0,0 +1,512 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from tests import unittest
from twisted.internet import defer

from mock import Mock, NonCallableMock
from tests.utils import (
    MockHttpResource, MockClock, DeferredMockCallable, SQLiteMemoryDbPool,
    MockKey
)

from synapse.server import HomeServer
from synapse.types import UserID
from synapse.api.filtering import Filter

user_localpart = "test_user"
MockEvent = namedtuple("MockEvent", "sender type room_id")


class FilteringTestCase(unittest.TestCase):

    @defer.inlineCallbacks
    def setUp(self):
        db_pool = SQLiteMemoryDbPool()
        yield db_pool.prepare()

        self.mock_config = NonCallableMock()
        self.mock_config.signing_key = [MockKey()]

        self.mock_federation_resource = MockHttpResource()

        self.mock_http_client = Mock(spec=[])
        self.mock_http_client.put_json = DeferredMockCallable()

        hs = HomeServer("test",
            db_pool=db_pool,
            handlers=None,
            http_client=self.mock_http_client,
            config=self.mock_config,
            keyring=Mock(),
        )

        self.filtering = hs.get_filtering()
        self.filter = Filter({})

        self.datastore = hs.get_datastore()

    def test_definition_types_works_with_literals(self):
        definition = {
            "types": ["m.room.message", "org.matrix.foo.bar"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!foo:bar"
        )

        self.assertTrue(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_types_works_with_wildcards(self):
        definition = {
            "types": ["m.*", "org.matrix.foo.bar"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!foo:bar"
        )

        self.assertTrue(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_types_works_with_unknowns(self):
        definition = {
            "types": ["m.room.message", "org.matrix.foo.bar"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="now.for.something.completely.different",
            room_id="!foo:bar"
        )

        self.assertFalse(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_not_types_works_with_literals(self):
        definition = {
            "not_types": ["m.room.message", "org.matrix.foo.bar"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!foo:bar"
        )

        self.assertFalse(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_not_types_works_with_wildcards(self):
        definition = {
            "not_types": ["m.room.message", "org.matrix.*"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="org.matrix.custom.event",
            room_id="!foo:bar"
        )

        self.assertFalse(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_not_types_works_with_unknowns(self):
        definition = {
            "not_types": ["m.*", "org.*"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="com.nom.nom.nom",
            room_id="!foo:bar"
        )

        self.assertTrue(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_not_types_takes_priority_over_types(self):
        definition = {
            "not_types": ["m.*", "org.*"],
            "types": ["m.room.message", "m.room.topic"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.topic",
            room_id="!foo:bar"
        )

        self.assertFalse(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_senders_works_with_literals(self):
        definition = {
            "senders": ["@flibble:wibble"]
        }
        event = MockEvent(
            sender="@flibble:wibble",
            type="com.nom.nom.nom",
            room_id="!foo:bar"
        )

        self.assertTrue(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_senders_works_with_unknowns(self):
        definition = {
            "senders": ["@flibble:wibble"]
        }
        event = MockEvent(
            sender="@challenger:appears",
            type="com.nom.nom.nom",
            room_id="!foo:bar"
        )

        self.assertFalse(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_not_senders_works_with_literals(self):
        definition = {
            "not_senders": ["@flibble:wibble"]
        }
        event = MockEvent(
            sender="@flibble:wibble",
            type="com.nom.nom.nom",
            room_id="!foo:bar"
        )

        self.assertFalse(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_not_senders_works_with_unknowns(self):
        definition = {
            "not_senders": ["@flibble:wibble"]
        }
        event = MockEvent(
            sender="@challenger:appears",
            type="com.nom.nom.nom",
            room_id="!foo:bar"
        )

        self.assertTrue(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_not_senders_takes_priority_over_senders(self):
        definition = {
            "not_senders": ["@misspiggy:muppets"],
            "senders": ["@kermit:muppets", "@misspiggy:muppets"]
        }
        event = MockEvent(
            sender="@misspiggy:muppets",
            type="m.room.topic",
            room_id="!foo:bar"
        )

        self.assertFalse(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_rooms_works_with_literals(self):
        definition = {
            "rooms": ["!secretbase:unknown"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!secretbase:unknown"
        )

        self.assertTrue(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_rooms_works_with_unknowns(self):
        definition = {
            "rooms": ["!secretbase:unknown"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!anothersecretbase:unknown"
        )

        self.assertFalse(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_not_rooms_works_with_literals(self):
        definition = {
            "not_rooms": ["!anothersecretbase:unknown"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!anothersecretbase:unknown"
        )

        self.assertFalse(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_not_rooms_works_with_unknowns(self):
        definition = {
            "not_rooms": ["!secretbase:unknown"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!anothersecretbase:unknown"
        )

        self.assertTrue(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_not_rooms_takes_priority_over_rooms(self):
        definition = {
            "not_rooms": ["!secretbase:unknown"],
            "rooms": ["!secretbase:unknown"]
        }
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.message",
            room_id="!secretbase:unknown"
        )

        self.assertFalse(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_combined_event(self):
        definition = {
            "not_senders": ["@misspiggy:muppets"],
            "senders": ["@kermit:muppets"],
            "rooms": ["!stage:unknown"],
            "not_rooms": ["!piggyshouse:muppets"],
            "types": ["m.room.message", "muppets.kermit.*"],
            "not_types": ["muppets.misspiggy.*"]
        }
        event = MockEvent(
            sender="@kermit:muppets",  # yup
            type="m.room.message",  # yup
            room_id="!stage:unknown"  # yup
        )

        self.assertTrue(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_combined_event_bad_sender(self):
        definition = {
            "not_senders": ["@misspiggy:muppets"],
            "senders": ["@kermit:muppets"],
            "rooms": ["!stage:unknown"],
            "not_rooms": ["!piggyshouse:muppets"],
            "types": ["m.room.message", "muppets.kermit.*"],
            "not_types": ["muppets.misspiggy.*"]
        }
        event = MockEvent(
            sender="@misspiggy:muppets",  # nope
            type="m.room.message",  # yup
            room_id="!stage:unknown"  # yup
        )

        self.assertFalse(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_combined_event_bad_room(self):
        definition = {
            "not_senders": ["@misspiggy:muppets"],
            "senders": ["@kermit:muppets"],
            "rooms": ["!stage:unknown"],
            "not_rooms": ["!piggyshouse:muppets"],
            "types": ["m.room.message", "muppets.kermit.*"],
            "not_types": ["muppets.misspiggy.*"]
        }
        event = MockEvent(
            sender="@kermit:muppets",  # yup
            type="m.room.message",  # yup
            room_id="!piggyshouse:muppets"  # nope
        )

        self.assertFalse(
            self.filter._passes_definition(definition, event)
        )

    def test_definition_combined_event_bad_type(self):
        definition = {
            "not_senders": ["@misspiggy:muppets"],
            "senders": ["@kermit:muppets"],
            "rooms": ["!stage:unknown"],
            "not_rooms": ["!piggyshouse:muppets"],
            "types": ["m.room.message", "muppets.kermit.*"],
            "not_types": ["muppets.misspiggy.*"]
        }
        event = MockEvent(
            sender="@kermit:muppets",  # yup
            type="muppets.misspiggy.kisses",  # nope
            room_id="!stage:unknown"  # yup
        )

        self.assertFalse(
            self.filter._passes_definition(definition, event)
        )

    @defer.inlineCallbacks
    def test_filter_public_user_data_match(self):
        user_filter_json = {
            "public_user_data": {
                "types": ["m.*"]
            }
        }
        user = UserID.from_string("@" + user_localpart + ":test")
        filter_id = yield self.datastore.add_user_filter(
            user_localpart=user_localpart,
            user_filter=user_filter_json,
        )
        event = MockEvent(
            sender="@foo:bar",
            type="m.profile",
            room_id="!foo:bar"
        )
        events = [event]

        user_filter = yield self.filtering.get_user_filter(
            user_localpart=user_localpart,
            filter_id=filter_id,
        )

        results = user_filter.filter_public_user_data(events=events)
        self.assertEquals(events, results)

    @defer.inlineCallbacks
    def test_filter_public_user_data_no_match(self):
        user_filter_json = {
            "public_user_data": {
                "types": ["m.*"]
            }
        }
        user = UserID.from_string("@" + user_localpart + ":test")
        filter_id = yield self.datastore.add_user_filter(
            user_localpart=user_localpart,
            user_filter=user_filter_json,
        )
        event = MockEvent(
            sender="@foo:bar",
            type="custom.avatar.3d.crazy",
            room_id="!foo:bar"
        )
        events = [event]

        user_filter = yield self.filtering.get_user_filter(
            user_localpart=user_localpart,
            filter_id=filter_id,
        )

        results = user_filter.filter_public_user_data(events=events)
        self.assertEquals([], results)

    @defer.inlineCallbacks
    def test_filter_room_state_match(self):
        user_filter_json = {
            "room": {
                "state": {
                    "types": ["m.*"]
                }
            }
        }
        user = UserID.from_string("@" + user_localpart + ":test")
        filter_id = yield self.datastore.add_user_filter(
            user_localpart=user_localpart,
            user_filter=user_filter_json,
        )
        event = MockEvent(
            sender="@foo:bar",
            type="m.room.topic",
            room_id="!foo:bar"
        )
        events = [event]

        user_filter = yield self.filtering.get_user_filter(
            user_localpart=user_localpart,
            filter_id=filter_id,
        )

        results = user_filter.filter_room_state(events=events)
        self.assertEquals(events, results)

    @defer.inlineCallbacks
    def test_filter_room_state_no_match(self):
        user_filter_json = {
            "room": {
                "state": {
                    "types": ["m.*"]
                }
            }
        }
        user = UserID.from_string("@" + user_localpart + ":test")
        filter_id = yield self.datastore.add_user_filter(
            user_localpart=user_localpart,
            user_filter=user_filter_json,
        )
        event = MockEvent(
            sender="@foo:bar",
            type="org.matrix.custom.event",
            room_id="!foo:bar"
        )
        events = [event]

        user_filter = yield self.filtering.get_user_filter(
            user_localpart=user_localpart,
            filter_id=filter_id,
        )

        results = user_filter.filter_room_state(events)
        self.assertEquals([], results)

    @defer.inlineCallbacks
    def test_add_filter(self):
        user_filter_json = {
            "room": {
                "state": {
                    "types": ["m.*"]
                }
            }
        }

        filter_id = yield self.filtering.add_user_filter(
            user_localpart=user_localpart,
            user_filter=user_filter_json,
        )

        self.assertEquals(filter_id, 0)
        self.assertEquals(user_filter_json,
            (yield self.datastore.get_user_filter(
                user_localpart=user_localpart,
                filter_id=0,
            ))
        )

    @defer.inlineCallbacks
    def test_get_filter(self):
        user_filter_json = {
            "room": {
                "state": {
                    "types": ["m.*"]
                }
            }
        }

        filter_id = yield self.datastore.add_user_filter(
            user_localpart=user_localpart,
            user_filter=user_filter_json,
        )

        filter = yield self.filtering.get_user_filter(
            user_localpart=user_localpart,
            filter_id=filter_id,
        )

        self.assertEquals(filter.filter_json, user_filter_json)


@@ -39,9 +39,7 @@ class V2AlphaRestTestCase(unittest.TestCase):
         hs = HomeServer("test",
             db_pool=None,
-            datastore=Mock(spec=[
-                "insert_client_ip",
-            ]),
+            datastore=self.make_datastore_mock(),
             http_client=None,
             resource_for_client=self.mock_resource,
             resource_for_federation=self.mock_resource,

@@ -53,8 +51,14 @@ class V2AlphaRestTestCase(unittest.TestCase):
                 "user": UserID.from_string(self.USER_ID),
                 "admin": False,
                 "device_id": None,
+                "token_id": 1,
             }
         hs.get_auth().get_user_by_token = _get_user_by_token

         for r in self.TO_REGISTER:
             r.register_servlets(hs, self.mock_resource)
+
+    def make_datastore_mock(self):
+        return Mock(spec=[
+            "insert_client_ip",
+        ])


@@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer

from mock import Mock

from . import V2AlphaRestTestCase

from synapse.rest.client.v2_alpha import filter

from synapse.api.errors import StoreError


class FilterTestCase(V2AlphaRestTestCase):
    USER_ID = "@apple:test"
    TO_REGISTER = [filter]

    def make_datastore_mock(self):
        datastore = super(FilterTestCase, self).make_datastore_mock()

        self._user_filters = {}

        def add_user_filter(user_localpart, definition):
            filters = self._user_filters.setdefault(user_localpart, [])
            filter_id = len(filters)
            filters.append(definition)
            return defer.succeed(filter_id)
        datastore.add_user_filter = add_user_filter

        def get_user_filter(user_localpart, filter_id):
            if user_localpart not in self._user_filters:
                raise StoreError(404, "No user")
            filters = self._user_filters[user_localpart]
            if filter_id >= len(filters):
                raise StoreError(404, "No filter")
            return defer.succeed(filters[filter_id])
        datastore.get_user_filter = get_user_filter

        return datastore

    @defer.inlineCallbacks
    def test_add_filter(self):
        (code, response) = yield self.mock_resource.trigger("POST",
            "/user/%s/filter" % (self.USER_ID),
            '{"type": ["m.*"]}'
        )
        self.assertEquals(200, code)
        self.assertEquals({"filter_id": "0"}, response)

        self.assertIn("apple", self._user_filters)
        self.assertEquals(len(self._user_filters["apple"]), 1)
        self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0])

    @defer.inlineCallbacks
    def test_get_filter(self):
        self._user_filters["apple"] = [
            {"type": ["m.*"]}
        ]

        (code, response) = yield self.mock_resource.trigger("GET",
            "/user/%s/filter/0" % (self.USER_ID), None
        )
        self.assertEquals(200, code)
        self.assertEquals({"type": ["m.*"]}, response)

    @defer.inlineCallbacks
    def test_get_filter_no_id(self):
        self._user_filters["apple"] = [
            {"type": ["m.*"]}
        ]

        (code, response) = yield self.mock_resource.trigger("GET",
            "/user/%s/filter/2" % (self.USER_ID), None
        )
        self.assertEquals(404, code)

    @defer.inlineCallbacks
    def test_get_filter_no_user(self):
        (code, response) = yield self.mock_resource.trigger("GET",
            "/user/%s/filter/0" % (self.USER_ID), None
        )
        self.assertEquals(404, code)