Sanitize filters

This commit is contained in:
Erik Johnston 2016-01-22 10:41:30 +00:00
parent 297eded261
commit 975903ae17
3 changed files with 40 additions and 34 deletions

View file

@ -28,14 +28,14 @@ class Filtering(object):
return result
def add_user_filter(self, user_localpart, user_filter):
self._check_valid_filter(user_filter)
self.check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, user_filter)
# TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for
# them however
def _check_valid_filter(self, user_filter_json):
def check_valid_filter(self, user_filter_json):
"""Check if the provided filter is valid.
This inspects all definitions contained within the filter.
@ -129,52 +129,55 @@ class Filtering(object):
class FilterCollection(object):
def __init__(self, filter_json):
self.filter_json = filter_json
self._filter_json = filter_json
room_filter_json = self.filter_json.get("room", {})
room_filter_json = self._filter_json.get("room", {})
self.room_filter = Filter({
self._room_filter = Filter({
k: v for k, v in room_filter_json.items()
if k in ("rooms", "not_rooms")
})
self.room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
self.room_state_filter = Filter(room_filter_json.get("state", {}))
self.room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {}))
self.room_account_data = Filter(room_filter_json.get("account_data", {}))
self.presence_filter = Filter(self.filter_json.get("presence", {}))
self.account_data = Filter(self.filter_json.get("account_data", {}))
self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
self._room_state_filter = Filter(room_filter_json.get("state", {}))
self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {}))
self._room_account_data = Filter(room_filter_json.get("account_data", {}))
self._presence_filter = Filter(filter_json.get("presence", {}))
self._account_data = Filter(filter_json.get("account_data", {}))
self.include_leave = self.filter_json.get("room", {}).get(
self.include_leave = filter_json.get("room", {}).get(
"include_leave", False
)
def get_filter_json(self):
return self._filter_json
def timeline_limit(self):
return self.room_timeline_filter.limit()
return self._room_timeline_filter.limit()
def presence_limit(self):
return self.presence_filter.limit()
return self._presence_filter.limit()
def ephemeral_limit(self):
return self.room_ephemeral_filter.limit()
return self._room_ephemeral_filter.limit()
def filter_presence(self, events):
return self.presence_filter.filter(events)
return self._presence_filter.filter(events)
def filter_account_data(self, events):
return self.account_data.filter(events)
return self._account_data.filter(events)
def filter_room_state(self, events):
return self.room_state_filter.filter(self.room_filter.filter(events))
return self._room_state_filter.filter(self._room_filter.filter(events))
def filter_room_timeline(self, events):
return self.room_timeline_filter.filter(self.room_filter.filter(events))
return self._room_timeline_filter.filter(self._room_filter.filter(events))
def filter_room_ephemeral(self, events):
return self.room_ephemeral_filter.filter(self.room_filter.filter(events))
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
def filter_room_account_data(self, events):
return self.room_account_data.filter(self.room_filter.filter(events))
return self._room_account_data.filter(self._room_filter.filter(events))
class Filter(object):
@ -258,3 +261,6 @@ def _matches_wildcard(actual_value, filter_value):
return actual_value.startswith(type_prefix)
else:
return actual_value == filter_value
DEFAULT_FILTER_COLLECTION = FilterCollection({})

View file

@ -59,7 +59,7 @@ class GetFilterRestServlet(RestServlet):
filter_id=filter_id,
)
defer.returnValue((200, filter.filter_json))
defer.returnValue((200, filter.get_filter_json()))
except KeyError:
raise SynapseError(400, "No such filter")

View file

@ -24,7 +24,7 @@ from synapse.events import FrozenEvent
from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_room_id,
)
from synapse.api.filtering import FilterCollection
from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION
from synapse.api.errors import SynapseError
from ._base import client_v2_patterns
@ -113,20 +113,20 @@ class SyncRestServlet(RestServlet):
)
)
if filter_id and filter_id.startswith('{'):
if filter_id:
if filter_id.startswith('{'):
try:
filter_object = json.loads(filter_id)
except:
raise SynapseError(400, "Invalid filter JSON")
self.filtering._check_valid_filter(filter_object)
self.filtering.check_valid_filter(filter_object)
filter = FilterCollection(filter_object)
else:
try:
filter = yield self.filtering.get_user_filter(
user.localpart, filter_id
)
except:
filter = FilterCollection({})
else:
filter = DEFAULT_FILTER_COLLECTION
sync_config = SyncConfig(
user=user,