0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-15 22:42:23 +01:00

Convert push to async/await. (#7948)

This commit is contained in:
Patrick Cloke 2020-07-27 12:21:34 -04:00 committed by GitHub
parent 7c2e2c2077
commit 8144bc26a7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 106 additions and 145 deletions

1
changelog.d/7948.misc Normal file
View file

@ -0,0 +1 @@
Convert push to async/await.

View file

@ -15,8 +15,6 @@
import logging
from twisted.internet import defer
from synapse.util.metrics import Measure
from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
@ -37,7 +35,6 @@ class ActionGenerator(object):
# event stream, so we just run the rules for a client with no profile
# tag (ie. we just need all the users).
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context):
async def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "action_for_event_by_user"):
yield self.bulk_evaluator.action_for_event_by_user(event, context)
await self.bulk_evaluator.action_for_event_by_user(event, context)

View file

@ -19,8 +19,6 @@ from collections import namedtuple
from prometheus_client import Counter
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.event_auth import get_user_power_level
from synapse.state import POWER_KEY
@ -70,8 +68,7 @@ class BulkPushRuleEvaluator(object):
resizable=False,
)
@defer.inlineCallbacks
def _get_rules_for_event(self, event, context):
async def _get_rules_for_event(self, event, context):
"""This gets the rules for all users in the room at the time of the event,
as well as the push rules for the invitee if the event is an invite.
@ -79,19 +76,19 @@ class BulkPushRuleEvaluator(object):
dict of user_id -> push_rules
"""
room_id = event.room_id
rules_for_room = yield self._get_rules_for_room(room_id)
rules_for_room = await self._get_rules_for_room(room_id)
rules_by_user = yield rules_for_room.get_rules(event, context)
rules_by_user = await rules_for_room.get_rules(event, context)
# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
if event.type == "m.room.member" and event.content["membership"] == "invite":
invited = event.state_key
if invited and self.hs.is_mine_id(invited):
has_pusher = yield self.store.user_has_pusher(invited)
has_pusher = await self.store.user_has_pusher(invited)
if has_pusher:
rules_by_user = dict(rules_by_user)
rules_by_user[invited] = yield self.store.get_push_rules_for_user(
rules_by_user[invited] = await self.store.get_push_rules_for_user(
invited
)
@ -114,20 +111,19 @@ class BulkPushRuleEvaluator(object):
self.room_push_rule_cache_metrics,
)
@defer.inlineCallbacks
def _get_power_levels_and_sender_level(self, event, context):
prev_state_ids = yield context.get_prev_state_ids()
async def _get_power_levels_and_sender_level(self, event, context):
prev_state_ids = await context.get_prev_state_ids()
pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case
pl_event = yield self.store.get_event(pl_event_id)
pl_event = await self.store.get_event(pl_event_id)
auth_events = {POWER_KEY: pl_event}
else:
auth_events_ids = yield self.auth.compute_auth_events(
auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
sender_level = get_user_power_level(event.sender, auth_events)
@ -136,23 +132,19 @@ class BulkPushRuleEvaluator(object):
return pl_event.content if pl_event else {}, sender_level
@defer.inlineCallbacks
def action_for_event_by_user(self, event, context):
async def action_for_event_by_user(self, event, context) -> None:
"""Given an event and context, evaluate the push rules and insert the
results into the event_push_actions_staging table.
Returns:
Deferred
"""
rules_by_user = yield self._get_rules_for_event(event, context)
rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user = {}
room_members = yield self.store.get_joined_users_from_context(event, context)
room_members = await self.store.get_joined_users_from_context(event, context)
(
power_levels,
sender_power_level,
) = yield self._get_power_levels_and_sender_level(event, context)
) = await self._get_power_levels_and_sender_level(event, context)
evaluator = PushRuleEvaluatorForEvent(
event, len(room_members), sender_power_level, power_levels
@ -165,7 +157,7 @@ class BulkPushRuleEvaluator(object):
continue
if not event.is_state():
is_ignored = yield self.store.is_ignored_by(event.sender, uid)
is_ignored = await self.store.is_ignored_by(event.sender, uid)
if is_ignored:
continue
@ -197,7 +189,7 @@ class BulkPushRuleEvaluator(object):
# Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist
# the event)
yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
def _condition_checker(evaluator, conditions, uid, display_name, cache):
@ -274,8 +266,7 @@ class RulesForRoom(object):
# to self around in the callback.
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
@defer.inlineCallbacks
def get_rules(self, event, context):
async def get_rules(self, event, context):
"""Given an event context return the rules for all users who are
currently in the room.
"""
@ -286,7 +277,7 @@ class RulesForRoom(object):
self.room_push_rule_cache_metrics.inc_hits()
return self.rules_by_user
with (yield self.linearizer.queue(())):
with (await self.linearizer.queue(())):
if state_group and self.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
@ -304,9 +295,7 @@ class RulesForRoom(object):
push_rules_delta_state_cache_metric.inc_hits()
else:
current_state_ids = yield defer.ensureDeferred(
context.get_current_state_ids()
)
current_state_ids = await context.get_current_state_ids()
push_rules_delta_state_cache_metric.inc_misses()
push_rules_state_size_counter.inc(len(current_state_ids))
@ -353,7 +342,7 @@ class RulesForRoom(object):
# If we have some memebr events we haven't seen, look them up
# and fetch push rules for them if appropriate.
logger.debug("Found new member events %r", missing_member_event_ids)
yield self._update_rules_with_member_event_ids(
await self._update_rules_with_member_event_ids(
ret_rules_by_user, missing_member_event_ids, state_group, event
)
else:
@ -371,8 +360,7 @@ class RulesForRoom(object):
)
return ret_rules_by_user
@defer.inlineCallbacks
def _update_rules_with_member_event_ids(
async def _update_rules_with_member_event_ids(
self, ret_rules_by_user, member_event_ids, state_group, event
):
"""Update the partially filled rules_by_user dict by fetching rules for
@ -388,7 +376,7 @@ class RulesForRoom(object):
"""
sequence = self.sequence
rows = yield self.store.get_membership_from_event_ids(member_event_ids.values())
rows = await self.store.get_membership_from_event_ids(member_event_ids.values())
members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
@ -410,7 +398,7 @@ class RulesForRoom(object):
logger.debug("Joined: %r", interested_in_user_ids)
if_users_with_pushers = yield self.store.get_if_users_have_pushers(
if_users_with_pushers = await self.store.get_if_users_have_pushers(
interested_in_user_ids, on_invalidate=self.invalidate_all_cb
)
@ -420,7 +408,7 @@ class RulesForRoom(object):
logger.debug("With pushers: %r", user_ids)
users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
self.room_id, on_invalidate=self.invalidate_all_cb
)
@ -431,7 +419,7 @@ class RulesForRoom(object):
if uid in interested_in_user_ids:
user_ids.add(uid)
rules_by_user = yield self.store.bulk_get_push_rules(
rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb
)

View file

@ -17,7 +17,6 @@ import logging
from prometheus_client import Counter
from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.api.constants import EventTypes
@ -128,12 +127,11 @@ class HttpPusher(object):
# but currently that's the only type of receipt anyway...
run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
@defer.inlineCallbacks
def _update_badge(self):
async def _update_badge(self):
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it.
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
yield self._send_badge(badge)
badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
await self._send_badge(badge)
def on_timer(self):
self._start_processing()
@ -152,8 +150,7 @@ class HttpPusher(object):
run_as_background_process("httppush.process", self._process)
@defer.inlineCallbacks
def _process(self):
async def _process(self):
# we should never get here if we are already processing
assert not self._is_processing
@ -164,7 +161,7 @@ class HttpPusher(object):
while True:
starting_max_ordering = self.max_stream_ordering
try:
yield self._unsafe_process()
await self._unsafe_process()
except Exception:
logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering:
@ -172,8 +169,7 @@ class HttpPusher(object):
finally:
self._is_processing = False
@defer.inlineCallbacks
def _unsafe_process(self):
async def _unsafe_process(self):
"""
Looks for unset notifications and dispatch them, in order
Never call this directly: use _process which will only allow this to
@ -181,7 +177,7 @@ class HttpPusher(object):
"""
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
unprocessed = yield fn(
unprocessed = await fn(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
@ -203,13 +199,13 @@ class HttpPusher(object):
"app_display_name": self.app_display_name,
},
):
processed = yield self._process_one(push_action)
processed = await self._process_one(push_action)
if processed:
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id,
self.pushkey,
self.user_id,
@ -224,14 +220,14 @@ class HttpPusher(object):
if self.failing_since:
self.failing_since = None
yield self.store.update_pusher_failing_since(
await self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id, self.failing_since
)
else:
http_push_failed_counter.inc()
if not self.failing_since:
self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since(
await self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id, self.failing_since
)
@ -250,7 +246,7 @@ class HttpPusher(object):
)
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
pusher_still_exists = yield self.store.update_pusher_last_stream_ordering(
pusher_still_exists = await self.store.update_pusher_last_stream_ordering(
self.app_id,
self.pushkey,
self.user_id,
@ -263,7 +259,7 @@ class HttpPusher(object):
return
self.failing_since = None
yield self.store.update_pusher_failing_since(
await self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id, self.failing_since
)
else:
@ -276,18 +272,17 @@ class HttpPusher(object):
)
break
@defer.inlineCallbacks
def _process_one(self, push_action):
async def _process_one(self, push_action):
if "notify" not in push_action["actions"]:
return True
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
event = yield self.store.get_event(push_action["event_id"], allow_none=True)
event = await self.store.get_event(push_action["event_id"], allow_none=True)
if event is None:
return True # It's been redacted
rejected = yield self.dispatch_push(event, tweaks, badge)
rejected = await self.dispatch_push(event, tweaks, badge)
if rejected is False:
return False
@ -301,11 +296,10 @@ class HttpPusher(object):
)
else:
logger.info("Pushkey %s was rejected: removing", pk)
yield self.hs.remove_pusher(self.app_id, pk, self.user_id)
await self.hs.remove_pusher(self.app_id, pk, self.user_id)
return True
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
async def _build_notification_dict(self, event, tweaks, badge):
priority = "low"
if (
event.type == EventTypes.Encrypted
@ -335,7 +329,7 @@ class HttpPusher(object):
}
return d
ctx = yield push_tools.get_context_for_event(
ctx = await push_tools.get_context_for_event(
self.storage, self.state_handler, event, self.user_id
)
@ -377,13 +371,12 @@ class HttpPusher(object):
return d
@defer.inlineCallbacks
def dispatch_push(self, event, tweaks, badge):
notification_dict = yield self._build_notification_dict(event, tweaks, badge)
async def dispatch_push(self, event, tweaks, badge):
notification_dict = await self._build_notification_dict(event, tweaks, badge)
if not notification_dict:
return []
try:
resp = yield self.http_client.post_json_get_json(
resp = await self.http_client.post_json_get_json(
self.url, notification_dict
)
except Exception as e:
@ -400,8 +393,7 @@ class HttpPusher(object):
rejected = resp["rejected"]
return rejected
@defer.inlineCallbacks
def _send_badge(self, badge):
async def _send_badge(self, badge):
"""
Args:
badge (int): number of unread messages
@ -424,7 +416,7 @@ class HttpPusher(object):
}
}
try:
yield self.http_client.post_json_get_json(self.url, d)
await self.http_client.post_json_get_json(self.url, d)
http_badges_processed_counter.inc()
except Exception as e:
logger.warning(

View file

@ -16,8 +16,6 @@
import logging
import re
from twisted.internet import defer
from synapse.api.constants import EventTypes
logger = logging.getLogger(__name__)
@ -29,8 +27,7 @@ ALIAS_RE = re.compile(r"^#.*:.+$")
ALL_ALONE = "Empty Room"
@defer.inlineCallbacks
def calculate_room_name(
async def calculate_room_name(
store,
room_state_ids,
user_id,
@ -53,7 +50,7 @@ def calculate_room_name(
"""
# does it have a name?
if (EventTypes.Name, "") in room_state_ids:
m_room_name = yield store.get_event(
m_room_name = await store.get_event(
room_state_ids[(EventTypes.Name, "")], allow_none=True
)
if m_room_name and m_room_name.content and m_room_name.content["name"]:
@ -61,7 +58,7 @@ def calculate_room_name(
# does it have a canonical alias?
if (EventTypes.CanonicalAlias, "") in room_state_ids:
canon_alias = yield store.get_event(
canon_alias = await store.get_event(
room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True
)
if (
@ -81,7 +78,7 @@ def calculate_room_name(
my_member_event = None
if (EventTypes.Member, user_id) in room_state_ids:
my_member_event = yield store.get_event(
my_member_event = await store.get_event(
room_state_ids[(EventTypes.Member, user_id)], allow_none=True
)
@ -90,7 +87,7 @@ def calculate_room_name(
and my_member_event.content["membership"] == "invite"
):
if (EventTypes.Member, my_member_event.sender) in room_state_ids:
inviter_member_event = yield store.get_event(
inviter_member_event = await store.get_event(
room_state_ids[(EventTypes.Member, my_member_event.sender)],
allow_none=True,
)
@ -107,7 +104,7 @@ def calculate_room_name(
# we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user.
if EventTypes.Member in room_state_bytype_ids:
member_events = yield store.get_events(
member_events = await store.get_events(
list(room_state_bytype_ids[EventTypes.Member].values())
)
all_members = [

View file

@ -13,18 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
@defer.inlineCallbacks
def get_badge_count(store, user_id):
invites = yield store.get_invited_rooms_for_local_user(user_id)
joins = yield store.get_rooms_for_user(user_id)
async def get_badge_count(store, user_id):
invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id)
my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read")
my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read")
badge = len(invites)
@ -32,7 +29,7 @@ def get_badge_count(store, user_id):
if room_id in my_receipts_by_room:
last_unread_event_id = my_receipts_by_room[room_id]
notifs = yield (
notifs = await (
store.get_unread_event_push_actions_by_room_for_user(
room_id, user_id, last_unread_event_id
)
@ -43,23 +40,22 @@ def get_badge_count(store, user_id):
return badge
@defer.inlineCallbacks
def get_context_for_event(storage: Storage, state_handler, ev, user_id):
async def get_context_for_event(storage: Storage, state_handler, ev, user_id):
ctx = {}
room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id)
room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id)
# we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or
# a list of people in the room
name = yield calculate_room_name(
name = await calculate_room_name(
storage.main, room_state_ids, user_id, fallback_to_single_member=False
)
if name:
ctx["name"] = name
sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
sender_state_event = yield storage.main.get_event(sender_state_event_id)
sender_state_event = await storage.main.get_event(sender_state_event_id)
ctx["sender_display_name"] = name_from_member_event(sender_state_event)
return ctx

View file

@ -19,8 +19,6 @@ from typing import TYPE_CHECKING, Dict, Union
from prometheus_client import Gauge
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
from synapse.push.emailpusher import EmailPusher
@ -52,7 +50,7 @@ class PusherPool:
Note that it is expected that each pusher will have its own 'processing' loop which
will send out the notifications in the background, rather than blocking until the
notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and
Pusher.on_new_receipts are not expected to return deferreds.
Pusher.on_new_receipts are not expected to return awaitables.
"""
def __init__(self, hs: "HomeServer"):
@ -77,8 +75,7 @@ class PusherPool:
return
run_as_background_process("start_pushers", self._start_pushers)
@defer.inlineCallbacks
def add_pusher(
async def add_pusher(
self,
user_id,
access_token,
@ -94,7 +91,7 @@ class PusherPool:
"""Creates a new pusher and adds it to the pool
Returns:
Deferred[EmailPusher|HttpPusher]
EmailPusher|HttpPusher
"""
time_now_msec = self.clock.time_msec()
@ -124,9 +121,9 @@ class PusherPool:
# create the pusher setting last_stream_ordering to the current maximum
# stream ordering in event_push_actions, so it will process
# pushes from this point onwards.
last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering()
last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
yield self.store.add_pusher(
await self.store.add_pusher(
user_id=user_id,
access_token=access_token,
kind=kind,
@ -140,15 +137,14 @@ class PusherPool:
last_stream_ordering=last_stream_ordering,
profile_tag=profile_tag,
)
pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id)
pusher = await self.start_pusher_by_id(app_id, pushkey, user_id)
return pusher
@defer.inlineCallbacks
def remove_pushers_by_app_id_and_pushkey_not_user(
async def remove_pushers_by_app_id_and_pushkey_not_user(
self, app_id, pushkey, not_user_id
):
to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
for p in to_remove:
if p["user_name"] != not_user_id:
logger.info(
@ -157,10 +153,9 @@ class PusherPool:
pushkey,
p["user_name"],
)
yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
@defer.inlineCallbacks
def remove_pushers_by_access_token(self, user_id, access_tokens):
async def remove_pushers_by_access_token(self, user_id, access_tokens):
"""Remove the pushers for a given user corresponding to a set of
access_tokens.
@ -173,7 +168,7 @@ class PusherPool:
return
tokens = set(access_tokens)
for p in (yield self.store.get_pushers_by_user_id(user_id)):
for p in await self.store.get_pushers_by_user_id(user_id):
if p["access_token"] in tokens:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
@ -181,16 +176,15 @@ class PusherPool:
p["pushkey"],
p["user_name"],
)
yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id):
async def on_new_notifications(self, min_stream_id, max_stream_id):
if not self.pushers:
# nothing to do here.
return
try:
users_affected = yield self.store.get_push_action_users_in_range(
users_affected = await self.store.get_push_action_users_in_range(
min_stream_id, max_stream_id
)
@ -202,8 +196,7 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_notifications")
@defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
if not self.pushers:
# nothing to do here.
return
@ -211,7 +204,7 @@ class PusherPool:
try:
# Need to subtract 1 from the minimum because the lower bound here
# is not inclusive
users_affected = yield self.store.get_users_sent_receipts_between(
users_affected = await self.store.get_users_sent_receipts_between(
min_stream_id - 1, max_stream_id
)
@ -223,12 +216,11 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_receipts")
@defer.inlineCallbacks
def start_pusher_by_id(self, app_id, pushkey, user_id):
async def start_pusher_by_id(self, app_id, pushkey, user_id):
"""Look up the details for the given pusher, and start it
Returns:
Deferred[EmailPusher|HttpPusher|None]: The pusher started, if any
EmailPusher|HttpPusher|None: The pusher started, if any
"""
if not self._should_start_pushers:
return
@ -236,7 +228,7 @@ class PusherPool:
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
pusher_dict = None
for r in resultlist:
@ -245,34 +237,29 @@ class PusherPool:
pusher = None
if pusher_dict:
pusher = yield self._start_pusher(pusher_dict)
pusher = await self._start_pusher(pusher_dict)
return pusher
@defer.inlineCallbacks
def _start_pushers(self):
async def _start_pushers(self) -> None:
"""Start all the pushers
Returns:
Deferred
"""
pushers = yield self.store.get_all_pushers()
pushers = await self.store.get_all_pushers()
# Stagger starting up the pushers so we don't completely drown the
# process on start up.
yield concurrently_execute(self._start_pusher, pushers, 10)
await concurrently_execute(self._start_pusher, pushers, 10)
logger.info("Started pushers")
@defer.inlineCallbacks
def _start_pusher(self, pusherdict):
async def _start_pusher(self, pusherdict):
"""Start the given pusher
Args:
pusherdict (dict): dict with the values pulled from the db table
Returns:
Deferred[EmailPusher|HttpPusher]
EmailPusher|HttpPusher
"""
if not self._pusher_shard_config.should_handle(
self._instance_name, pusherdict["user_name"]
@ -315,7 +302,7 @@ class PusherPool:
user_id = pusherdict["user_name"]
last_stream_ordering = pusherdict["last_stream_ordering"]
if last_stream_ordering:
have_notifs = yield self.store.get_if_maybe_push_in_range_for_user(
have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
user_id, last_stream_ordering
)
else:
@ -327,8 +314,7 @@ class PusherPool:
return p
@defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey, user_id):
async def remove_pusher(self, app_id, pushkey, user_id):
appid_pushkey = "%s:%s" % (app_id, pushkey)
byuser = self.pushers.get(user_id, {})
@ -340,6 +326,6 @@ class PusherPool:
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
yield self.store.delete_pusher_by_app_id_pushkey_user_id(
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id
)

View file

@ -411,7 +411,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
_get_if_maybe_push_in_range_for_user_txn,
)
def add_push_actions_to_staging(self, event_id, user_id_actions):
async def add_push_actions_to_staging(self, event_id, user_id_actions):
"""Add the push actions for the event to the push action staging area.
Args:
@ -457,7 +457,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
),
)
return self.db.runInteraction(
return await self.db.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
)

View file

@ -366,7 +366,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_handler = self.hs.get_state_handler()
context = self.get_success(state_handler.compute_event_context(event))
self.get_success(
self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
)
)
return event, context

View file

@ -72,9 +72,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream
event.depth = stream
yield self.store.add_push_actions_to_staging(
yield defer.ensureDeferred(
self.store.add_push_actions_to_staging(
event.event_id, {user_id: action}
)
)
yield self.store.db.runInteraction(
"",
self.persist_events_store._set_push_actions_for_event_and_users_txn,