0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-07-05 10:48:46 +02:00

Merge pull request #922 from matrix-org/erikj/file_api2

Feature: Add filter to /messages. Add 'contains_url' to filter.
This commit is contained in:
Erik Johnston 2016-07-20 10:40:48 +01:00 committed by GitHub
commit 1e2a7f18a1
8 changed files with 246 additions and 15 deletions

View file

@ -191,6 +191,17 @@ class Filter(object):
def __init__(self, filter_json): def __init__(self, filter_json):
self.filter_json = filter_json self.filter_json = filter_json
self.types = self.filter_json.get("types", None)
self.not_types = self.filter_json.get("not_types", [])
self.rooms = self.filter_json.get("rooms", None)
self.not_rooms = self.filter_json.get("not_rooms", [])
self.senders = self.filter_json.get("senders", None)
self.not_senders = self.filter_json.get("not_senders", [])
self.contains_url = self.filter_json.get("contains_url", None)
def check(self, event): def check(self, event):
"""Checks whether the filter matches the given event. """Checks whether the filter matches the given event.
@ -209,9 +220,10 @@ class Filter(object):
event.get("room_id", None), event.get("room_id", None),
sender, sender,
event.get("type", None), event.get("type", None),
"url" in event.get("content", {})
) )
def check_fields(self, room_id, sender, event_type): def check_fields(self, room_id, sender, event_type, contains_url):
"""Checks whether the filter matches the given event fields. """Checks whether the filter matches the given event fields.
Returns: Returns:
@ -225,15 +237,20 @@ class Filter(object):
for name, match_func in literal_keys.items(): for name, match_func in literal_keys.items():
not_name = "not_%s" % (name,) not_name = "not_%s" % (name,)
disallowed_values = self.filter_json.get(not_name, []) disallowed_values = getattr(self, not_name)
if any(map(match_func, disallowed_values)): if any(map(match_func, disallowed_values)):
return False return False
allowed_values = self.filter_json.get(name, None) allowed_values = getattr(self, name)
if allowed_values is not None: if allowed_values is not None:
if not any(map(match_func, allowed_values)): if not any(map(match_func, allowed_values)):
return False return False
contains_url_filter = self.filter_json.get("contains_url")
if contains_url_filter is not None:
if contains_url_filter != contains_url:
return False
return True return True
def filter_rooms(self, room_ids): def filter_rooms(self, room_ids):

View file

@ -66,7 +66,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None, def get_messages(self, requester, room_id=None, pagin_config=None,
as_client_event=True): as_client_event=True, event_filter=None):
"""Get messages in a room. """Get messages in a room.
Args: Args:
@ -75,11 +75,11 @@ class MessageHandler(BaseHandler):
pagin_config (synapse.api.streams.PaginationConfig): The pagination pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any. config rules to apply, if any.
as_client_event (bool): True to get events in client-server format. as_client_event (bool): True to get events in client-server format.
event_filter (Filter): Filter to apply to results or None
Returns: Returns:
dict: Pagination API results dict: Pagination API results
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
data_source = self.hs.get_event_sources().sources["room"]
if pagin_config.from_token: if pagin_config.from_token:
room_token = pagin_config.from_token.room_key room_token = pagin_config.from_token.room_key
@ -129,8 +129,13 @@ class MessageHandler(BaseHandler):
room_id, max_topo room_id, max_topo
) )
events, next_key = yield data_source.get_pagination_rows( events, next_key = yield self.store.paginate_room_events(
requester.user, source_config, room_id room_id=room_id,
from_key=source_config.from_key,
to_key=source_config.to_key,
direction=source_config.direction,
limit=source_config.limit,
event_filter=event_filter,
) )
next_token = pagin_config.from_token.copy_and_replace( next_token = pagin_config.from_token.copy_and_replace(
@ -144,6 +149,9 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(), "end": next_token.to_string(),
}) })
if event_filter:
events = event_filter.filter(events)
events = yield filter_events_for_client( events = yield filter_events_for_client(
self.store, self.store,
user_id, user_id,

View file

@ -20,12 +20,14 @@ from .base import ClientV1RestServlet, client_path_patterns
from synapse.api.errors import SynapseError, Codes, AuthError from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
import logging import logging
import urllib import urllib
import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -327,12 +329,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
request, default_limit=10, request, default_limit=10,
) )
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
filter_bytes = request.args.get("filter", None)
if filter_bytes:
filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
handler = self.handlers.message_handler handler = self.handlers.message_handler
msgs = yield handler.get_messages( msgs = yield handler.get_messages(
room_id=room_id, room_id=room_id,
requester=requester, requester=requester,
pagin_config=pagination_config, pagin_config=pagination_config,
as_client_event=as_client_event as_client_event=as_client_event,
event_filter=event_filter,
) )
defer.returnValue((200, msgs)) defer.returnValue((200, msgs))

View file

@ -152,6 +152,7 @@ _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
class EventsStore(SQLBaseStore): class EventsStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
def __init__(self, hs): def __init__(self, hs):
super(EventsStore, self).__init__(hs) super(EventsStore, self).__init__(hs)
@ -159,6 +160,10 @@ class EventsStore(SQLBaseStore):
self.register_background_update_handler( self.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
) )
self.register_background_update_handler(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
self._background_reindex_fields_sender,
)
self._event_persist_queue = _EventPeristenceQueue() self._event_persist_queue = _EventPeristenceQueue()
@ -576,6 +581,11 @@ class EventsStore(SQLBaseStore):
"content": encode_json(event.content).decode("UTF-8"), "content": encode_json(event.content).decode("UTF-8"),
"origin_server_ts": int(event.origin_server_ts), "origin_server_ts": int(event.origin_server_ts),
"received_ts": self._clock.time_msec(), "received_ts": self._clock.time_msec(),
"sender": event.sender,
"contains_url": (
"url" in event.content
and isinstance(event.content["url"], basestring)
),
} }
for event, _ in events_and_contexts for event, _ in events_and_contexts
], ],
@ -1115,6 +1125,78 @@ class EventsStore(SQLBaseStore):
ret = yield self.runInteraction("count_messages", _count_messages) ret = yield self.runInteraction("count_messages", _count_messages)
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks
def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
def reindex_txn(txn):
sql = (
"SELECT stream_ordering, event_id, json FROM events"
" INNER JOIN event_json USING (event_id)"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" ORDER BY stream_ordering DESC"
" LIMIT ?"
)
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
rows = txn.fetchall()
if not rows:
return 0
min_stream_id = rows[-1][0]
update_rows = []
for row in rows:
try:
event_id = row[1]
event_json = json.loads(row[2])
sender = event_json["sender"]
content = event_json["content"]
contains_url = "url" in content
if contains_url:
contains_url &= isinstance(content["url"], basestring)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
continue
update_rows.append((sender, contains_url, event_id))
sql = (
"UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
)
for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
clump = update_rows[index:index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
"rows_inserted": rows_inserted + len(rows)
}
self._background_update_progress_txn(
txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
)
return len(rows)
result = yield self.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_reindex_origin_server_ts(self, progress, batch_size): def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"] target_min_stream_id = progress["target_min_stream_id_inclusive"]

View file

@ -0,0 +1,60 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.prepare_database import get_statements
import logging
import ujson
logger = logging.getLogger(__name__)
ALTER_TABLE = """
ALTER TABLE events ADD COLUMN sender TEXT;
ALTER TABLE events ADD COLUMN contains_url BOOLEAN;
"""
def run_create(cur, database_engine, *args, **kwargs):
for statement in get_statements(ALTER_TABLE.splitlines()):
cur.execute(statement)
cur.execute("SELECT MIN(stream_ordering) FROM events")
rows = cur.fetchall()
min_stream_id = rows[0][0]
cur.execute("SELECT MAX(stream_ordering) FROM events")
rows = cur.fetchall()
max_stream_id = rows[0][0]
if min_stream_id is not None and max_stream_id is not None:
progress = {
"target_min_stream_id_inclusive": min_stream_id,
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
progress_json = ujson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
" VALUES (?, ?)"
)
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_fields_sender_url", progress_json))
def run_upgrade(cur, database_engine, *args, **kwargs):
pass

View file

@ -95,6 +95,54 @@ def upper_bound(token, engine, inclusive=True):
) )
def filter_to_clause(event_filter):
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise.
if not event_filter:
return "", []
clauses = []
args = []
if event_filter.types:
clauses.append(
"(%s)" % " OR ".join("type = ?" for _ in event_filter.types)
)
args.extend(event_filter.types)
for typ in event_filter.not_types:
clauses.append("type != ?")
args.append(typ)
if event_filter.senders:
clauses.append(
"(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)
)
args.extend(event_filter.senders)
for sender in event_filter.not_senders:
clauses.append("sender != ?")
args.append(sender)
if event_filter.rooms:
clauses.append(
"(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)
)
args.extend(event_filter.rooms)
for room_id in event_filter.not_rooms:
clauses.append("room_id != ?")
args.append(room_id)
if event_filter.contains_url:
clauses.append("contains_url = ?")
args.append(event_filter.contains_url)
return " AND ".join(clauses), args
class StreamStore(SQLBaseStore): class StreamStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_appservice_room_stream(self, service, from_key, to_key, limit=0): def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
@ -320,7 +368,7 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None, def paginate_room_events(self, room_id, from_key, to_key=None,
direction='b', limit=-1): direction='b', limit=-1, event_filter=None):
# Tokens really represent positions between elements, but we use # Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence # the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities. # we have a bit of asymmetry when it comes to equalities.
@ -344,6 +392,12 @@ class StreamStore(SQLBaseStore):
RoomStreamToken.parse(to_key), self.database_engine RoomStreamToken.parse(to_key), self.database_engine
)) ))
filter_clause, filter_args = filter_to_clause(event_filter)
if filter_clause:
bounds += " AND " + filter_clause
args.extend(filter_args)
if int(limit) > 0: if int(limit) > 0:
args.append(int(limit)) args.append(int(limit))
limit_str = " LIMIT ?" limit_str = " LIMIT ?"

View file

@ -30,6 +30,7 @@ class EventInjector:
def create_room(self, room): def create_room(self, room):
builder = self.event_builder_factory.new({ builder = self.event_builder_factory.new({
"type": EventTypes.Create, "type": EventTypes.Create,
"sender": "",
"room_id": room.to_string(), "room_id": room.to_string(),
"content": {}, "content": {},
}) })

View file

@ -37,7 +37,7 @@ class EventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_count_daily_messages(self): def test_count_daily_messages(self):
self.db_pool.runQuery("DELETE FROM stats_reporting") yield self.db_pool.runQuery("DELETE FROM stats_reporting")
self.hs.clock.now = 100 self.hs.clock.now = 100
@ -60,7 +60,7 @@ class EventsStoreTestCase(unittest.TestCase):
# it isn't old enough. # it isn't old enough.
count = yield self.store.count_daily_messages() count = yield self.store.count_daily_messages()
self.assertIsNone(count) self.assertIsNone(count)
self._assert_stats_reporting(1, self.hs.clock.now) yield self._assert_stats_reporting(1, self.hs.clock.now)
# Already reported yesterday, two new events from today. # Already reported yesterday, two new events from today.
yield self.event_injector.inject_message(room, user, "Yeah they are!") yield self.event_injector.inject_message(room, user, "Yeah they are!")
@ -68,21 +68,21 @@ class EventsStoreTestCase(unittest.TestCase):
self.hs.clock.now += 60 * 60 * 24 self.hs.clock.now += 60 * 60 * 24
count = yield self.store.count_daily_messages() count = yield self.store.count_daily_messages()
self.assertEqual(2, count) # 2 since yesterday self.assertEqual(2, count) # 2 since yesterday
self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever yield self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever
# Last reported too recently. # Last reported too recently.
yield self.event_injector.inject_message(room, user, "Who could disagree?") yield self.event_injector.inject_message(room, user, "Who could disagree?")
self.hs.clock.now += 60 * 60 * 22 self.hs.clock.now += 60 * 60 * 22
count = yield self.store.count_daily_messages() count = yield self.store.count_daily_messages()
self.assertIsNone(count) self.assertIsNone(count)
self._assert_stats_reporting(4, self.hs.clock.now) yield self._assert_stats_reporting(4, self.hs.clock.now)
# Last reported too long ago # Last reported too long ago
yield self.event_injector.inject_message(room, user, "No one.") yield self.event_injector.inject_message(room, user, "No one.")
self.hs.clock.now += 60 * 60 * 26 self.hs.clock.now += 60 * 60 * 26
count = yield self.store.count_daily_messages() count = yield self.store.count_daily_messages()
self.assertIsNone(count) self.assertIsNone(count)
self._assert_stats_reporting(5, self.hs.clock.now) yield self._assert_stats_reporting(5, self.hs.clock.now)
# And now let's actually report something # And now let's actually report something
yield self.event_injector.inject_message(room, user, "Indeed.") yield self.event_injector.inject_message(room, user, "Indeed.")
@ -92,7 +92,7 @@ class EventsStoreTestCase(unittest.TestCase):
self.hs.clock.now += (60 * 60 * 24) + 50 self.hs.clock.now += (60 * 60 * 24) + 50
count = yield self.store.count_daily_messages() count = yield self.store.count_daily_messages()
self.assertEqual(3, count) self.assertEqual(3, count)
self._assert_stats_reporting(8, self.hs.clock.now) yield self._assert_stats_reporting(8, self.hs.clock.now)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_last_stream_token(self): def _get_last_stream_token(self):