From 7be0f6594e2a6dd7c3dc745eb856025276ec7d1f Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Wed, 11 Feb 2015 15:53:56 +0000 Subject: [PATCH] First step of making user_rooms_intersect() faster - implement in intersection logic in Python code terms of a DB query that is cacheable per user --- synapse/storage/roommember.py | 36 ++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index c69dd995c..d490a374e 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -240,28 +240,30 @@ class RoomMemberStore(SQLBaseStore): results = self._parse_events_txn(txn, rows) return results + @defer.inlineCallbacks def user_rooms_intersect(self, user_id_list): """ Checks whether all the users whose IDs are given in a list share a room. + + This is a "hot path" function that's called a lot, e.g. by presence for + generating the event stream. """ - def interaction(txn): - user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_id_list)) - sql = ( - "SELECT m.room_id FROM room_memberships as m " - "INNER JOIN current_state_events as c " - "ON m.event_id = c.event_id " - "WHERE m.membership = 'join' " - "AND (%(clause)s) " - # TODO(paul): We've got duplicate rows in the database somewhere - # so we have to DISTINCT m.user_id here - "GROUP BY m.room_id HAVING COUNT(DISTINCT m.user_id) = ?" - ) % {"clause": user_list_clause} + if len(user_id_list) < 2: + defer.returnValue(True) - args = list(user_id_list) - args.append(len(user_id_list)) + deferreds = [ + self.get_rooms_for_user_where_membership_is( + u, membership_list=[Membership.JOIN], + ) + for u in user_id_list + ] - txn.execute(sql, args) + results = yield defer.DeferredList(deferreds) - return len(txn.fetchall()) > 0 + # A list of sets of strings giving room IDs for each user + room_id_lists = [set([r.room_id for r in result[1]]) for result in results] - return self.runInteraction("user_rooms_intersect", interaction) + # There isn't a setintersection(*list_of_sets) + ret = len(room_id_lists.pop(0).intersection(*room_id_lists)) > 0 + + defer.returnValue(ret)