From 8d73cd502bd8ee6903c81f20f79fe5e1509692e3 Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Fri, 1 Apr 2016 14:06:00 +0100
Subject: [PATCH] Add concurrently_execute function

---
 synapse/handlers/message.py | 10 +---
 synapse/handlers/room.py    | 17 +++----
 synapse/handlers/sync.py    | 96 ++++++++++++++++---------------------
 synapse/util/async.py       | 32 ++++++++++++-
 4 files changed, 81 insertions(+), 74 deletions(-)

diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 5c50c611b..0bb111d04 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -21,6 +21,7 @@ from synapse.streams.config import PaginationConfig
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
 from synapse.util import unwrapFirstError
+from synapse.util.async import concurrently_execute
 from synapse.util.caches.snapshot_cache import SnapshotCache
 from synapse.types import UserID, RoomStreamToken, StreamToken
 
@@ -556,14 +557,7 @@ class MessageHandler(BaseHandler):
             except:
                 logger.exception("Failed to get snapshot")
 
-        # Only do N rooms at once
-        n = 5
-        d_list = [handle_room(e) for e in room_list]
-        for i in range(0, len(d_list), n):
-            yield defer.gatherResults(
-                d_list[i:i + n],
-                consumeErrors=True
-            ).addErrback(unwrapFirstError)
+        yield concurrently_execute(handle_room, room_list, 10)
 
         account_data_events = []
         for account_data_type, content in account_data.items():
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index ee99ded21..3e1d9282d 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -23,7 +23,8 @@ from synapse.api.constants import (
     EventTypes, JoinRules, RoomCreationPreset,
 )
 from synapse.api.errors import AuthError, StoreError, SynapseError
-from synapse.util import stringutils, unwrapFirstError
+from synapse.util import stringutils
+from synapse.util.async import concurrently_execute
 from synapse.util.logcontext import preserve_context_over_fn
 from synapse.util.caches.response_cache import ResponseCache
 
@@ -368,6 +369,8 @@ class RoomListHandler(BaseHandler):
     def _get_public_room_list(self):
         room_ids = yield self.store.get_public_room_ids()
 
+        results = []
+
         @defer.inlineCallbacks
         def handle_room(room_id):
             aliases = yield self.store.get_aliases_for_room(room_id)
@@ -428,18 +431,12 @@ class RoomListHandler(BaseHandler):
             joined_users = yield self.store.get_users_in_room(room_id)
             result["num_joined_members"] = len(joined_users)
 
-            defer.returnValue(result)
+            results.append(result)
 
-        result = []
-        for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)):
-            chunk_result = yield defer.gatherResults([
-                handle_room(room_id)
-                for room_id in chunk
-            ], consumeErrors=True).addErrback(unwrapFirstError)
-            result.extend(v for v in chunk_result if v)
+        yield concurrently_execute(handle_room, room_ids, 10)
 
         # FIXME (erikj): START is no longer a valid value
-        defer.returnValue({"start": "START", "end": "END", "chunk": result})
+        defer.returnValue({"start": "START", "end": "END", "chunk": results})
 
 
 class RoomContextHandler(BaseHandler):
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 06098f899..e38fe1ef9 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -17,8 +17,8 @@ from ._base import BaseHandler
 
 from synapse.streams.config import PaginationConfig
 from synapse.api.constants import Membership, EventTypes
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.async import concurrently_execute
+from synapse.util.logcontext import LoggingContext
 from synapse.util.metrics import Measure
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.push.clientformat import format_push_rules_for_user
@@ -250,64 +250,50 @@ class SyncHandler(BaseHandler):
         joined = []
         invited = []
         archived = []
-        deferreds = []
 
         user_id = sync_config.user.to_string()
 
-        def _should_include_room(event):
-            # Always send down rooms we were banned or kicked from.
-            if not sync_config.filter_collection.include_leave:
-                if event.membership == Membership.LEAVE:
-                    if user_id == event.sender:
-                        return False
-            return True
+        @defer.inlineCallbacks
+        def _generate_room_entry(event):
+            if event.membership == Membership.JOIN:
+                room_result = yield self.full_state_sync_for_joined_room(
+                    room_id=event.room_id,
+                    sync_config=sync_config,
+                    now_token=now_token,
+                    timeline_since_token=timeline_since_token,
+                    ephemeral_by_room=ephemeral_by_room,
+                    tags_by_room=tags_by_room,
+                    account_data_by_room=account_data_by_room,
+                )
+                joined.append(room_result)
+            elif event.membership == Membership.INVITE:
+                invite = yield self.store.get_event(event.event_id)
+                invited.append(InvitedSyncResult(
+                    room_id=event.room_id,
+                    invite=invite,
+                ))
+            elif event.membership in (Membership.LEAVE, Membership.BAN):
+                # Always send down rooms we were banned or kicked from.
+                if not sync_config.filter_collection.include_leave:
+                    if event.membership == Membership.LEAVE:
+                        if user_id == event.sender:
+                            return
 
-        room_list = filter(_should_include_room, room_list)
+                leave_token = now_token.copy_and_replace(
+                    "room_key", "s%d" % (event.stream_ordering,)
+                )
+                room_result = yield self.full_state_sync_for_archived_room(
+                    sync_config=sync_config,
+                    room_id=event.room_id,
+                    leave_event_id=event.event_id,
+                    leave_token=leave_token,
+                    timeline_since_token=timeline_since_token,
+                    tags_by_room=tags_by_room,
+                    account_data_by_room=account_data_by_room,
+                )
+                archived.append(room_result)
 
-        room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)]
-        for room_list_chunk in room_list_chunks:
-            for event in room_list_chunk:
-                if event.membership == Membership.JOIN:
-                    room_sync_deferred = preserve_fn(
-                        self.full_state_sync_for_joined_room
-                    )(
-                        room_id=event.room_id,
-                        sync_config=sync_config,
-                        now_token=now_token,
-                        timeline_since_token=timeline_since_token,
-                        ephemeral_by_room=ephemeral_by_room,
-                        tags_by_room=tags_by_room,
-                        account_data_by_room=account_data_by_room,
-                    )
-                    room_sync_deferred.addCallback(joined.append)
-                    deferreds.append(room_sync_deferred)
-                elif event.membership == Membership.INVITE:
-                    invite = yield self.store.get_event(event.event_id)
-                    invited.append(InvitedSyncResult(
-                        room_id=event.room_id,
-                        invite=invite,
-                    ))
-                elif event.membership in (Membership.LEAVE, Membership.BAN):
-                    leave_token = now_token.copy_and_replace(
-                        "room_key", "s%d" % (event.stream_ordering,)
-                    )
-                    room_sync_deferred = preserve_fn(
-                        self.full_state_sync_for_archived_room
-                    )(
-                        sync_config=sync_config,
-                        room_id=event.room_id,
-                        leave_event_id=event.event_id,
-                        leave_token=leave_token,
-                        timeline_since_token=timeline_since_token,
-                        tags_by_room=tags_by_room,
-                        account_data_by_room=account_data_by_room,
-                    )
-                    room_sync_deferred.addCallback(archived.append)
-                    deferreds.append(room_sync_deferred)
-
-            yield defer.gatherResults(
-                deferreds, consumeErrors=True
-            ).addErrback(unwrapFirstError)
+        yield concurrently_execute(_generate_room_entry, room_list, 10)
 
         account_data_for_user = sync_config.filter_collection.filter_account_data(
             self.account_data_for_user(account_data)
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 640fae389..a75e1c71f 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,7 +16,8 @@
 
 from twisted.internet import defer, reactor
 
-from .logcontext import PreserveLoggingContext
+from .logcontext import PreserveLoggingContext, preserve_fn
+from synapse.util import unwrapFirstError
 
 
 @defer.inlineCallbacks
@@ -107,3 +108,32 @@ class ObservableDeferred(object):
         return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
             id(self), self._result, self._deferred,
         )
+
+
+def concurrently_execute(func, args, limit):
+    """Executes the function with each argument concurrently while limiting
+    the number of concurrent executions.
+
+    Args:
+        func (func): Function to execute, should return a deferred.
+        args (list): List of arguments to pass to func, each invocation of func
+            gets a single argument.
+        limit (int): Maximum number of concurrent executions.
+
+    Returns:
+        deferred: Resolves once func has been run on every argument.
+    """
+    it = iter(args)  # shared, so each argument is handed to one worker only
+
+    @defer.inlineCallbacks
+    def _concurrently_execute_inner():
+        try:
+            while True:
+                yield func(it.next())
+        except StopIteration:
+            pass
+
+    return defer.gatherResults([
+        preserve_fn(_concurrently_execute_inner)()
+        for _ in xrange(limit)
+    ], consumeErrors=True).addErrback(unwrapFirstError)