diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index d5c395061..8864a921f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -18,7 +18,7 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.api.constants import ( - EventTypes, KnownRoomEventKeys, Membership, SearchConstraintTypes + EventTypes, Membership, ) from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event @@ -29,45 +29,6 @@ import logging logger = logging.getLogger(__name__) -KEYS_TO_ALLOWED_CONSTRAINT_TYPES = { - KnownRoomEventKeys.CONTENT_BODY: [SearchConstraintTypes.FTS], - KnownRoomEventKeys.CONTENT_MSGTYPE: [SearchConstraintTypes.EXACT], - KnownRoomEventKeys.CONTENT_NAME: [ - SearchConstraintTypes.FTS, - SearchConstraintTypes.EXACT, - SearchConstraintTypes.SUBSTRING, - ], - KnownRoomEventKeys.CONTENT_TOPIC: [SearchConstraintTypes.FTS], - KnownRoomEventKeys.SENDER: [SearchConstraintTypes.EXACT], - KnownRoomEventKeys.ORIGIN_SERVER_TS: [SearchConstraintTypes.RANGE], - KnownRoomEventKeys.ROOM_ID: [SearchConstraintTypes.EXACT], -} - - -class RoomConstraint(object): - def __init__(self, search_type, keys, value): - self.search_type = search_type - self.keys = keys - self.value = value - - @classmethod - def from_dict(cls, d): - search_type = d["type"] - keys = d["keys"] - - for key in keys: - if key not in KEYS_TO_ALLOWED_CONSTRAINT_TYPES: - raise SynapseError(400, "Unrecognized key %r", key) - - if search_type not in KEYS_TO_ALLOWED_CONSTRAINT_TYPES[key]: - raise SynapseError( - 400, - "Disallowed constraint type %r for key %r", search_type, key - ) - - return cls(search_type, keys, d["value"]) - - class SearchHandler(BaseHandler): def __init__(self, hs): @@ -121,22 +82,20 @@ class SearchHandler(BaseHandler): @defer.inlineCallbacks def search(self, user, content): - constraint_dicts = content["search_categories"]["room_events"]["constraints"] - constraints = [RoomConstraint.from_dict(c)for c in constraint_dicts] - - fts = False - for c in constraints: - if c.search_type == SearchConstraintTypes.FTS: - if fts: - raise SynapseError(400, "Only one constraint can be FTS") - fts = True + try: + search_term = content["search_categories"]["room_events"]["search_term"] + keys = content["search_categories"]["room_events"]["keys"] + except KeyError: + raise SynapseError(400, "Invalid search query") rooms = yield self.store.get_rooms_for_user_where_membership_is( - user.to_string(), membership_list=[Membership.JOIN, Membership.LEAVE], + user.to_string(), + membership_list=[Membership.JOIN], + # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], ) room_ids = set(r.room_id for r in rooms) - rank_map, event_map = yield self.store.search_msgs(room_ids, constraints) + rank_map, event_map = yield self.store.search_msgs(room_ids, search_term, keys) allowed_events = yield self._filter_events_for_client( user.to_string(), event_map.values() diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 5843f8087..7a30ce25e 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -16,38 +16,28 @@ from twisted.internet import defer from _base import SQLBaseStore -from synapse.api.constants import KnownRoomEventKeys, SearchConstraintTypes from synapse.storage.engines import PostgresEngine class SearchStore(SQLBaseStore): @defer.inlineCallbacks - def search_msgs(self, room_ids, constraints): + def search_msgs(self, room_ids, search_term, keys): clauses = [] args = [] - fts = None clauses.append( "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),) ) args.extend(room_ids) - for c in constraints: - local_clauses = [] - if c.search_type == SearchConstraintTypes.FTS: - fts = c.value - for key in c.keys: - local_clauses.append("key = ?") - args.append(key) - elif c.search_type == SearchConstraintTypes.EXACT: - for key in c.keys: - if key == KnownRoomEventKeys.ROOM_ID: - for value in c.value: - local_clauses.append("room_id = ?") - args.append(value) - clauses.append( - "(%s)" % (" OR ".join(local_clauses),) - ) + local_clauses = [] + for key in keys: + local_clauses.append("key = ?") + args.append(key) + + clauses.append( + "(%s)" % (" OR ".join(local_clauses),) + ) if isinstance(self.database_engine, PostgresEngine): sql = ( @@ -67,7 +57,7 @@ class SearchStore(SQLBaseStore): sql += " ORDER BY rank DESC" results = yield self._execute( - "search_msgs", self.cursor_to_dict, sql, *([fts] + args) + "search_msgs", self.cursor_to_dict, sql, *([search_term] + args) ) events = yield self._get_events([r["event_id"] for r in results])