0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-29 07:58:19 +02:00

Convert appservice to async. (#7973)

This commit is contained in:
Patrick Cloke 2020-07-30 07:27:39 -04:00 committed by GitHub
parent b3a97d6dac
commit 4cce8ef74e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 122 additions and 103 deletions

1
changelog.d/7973.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -15,11 +15,9 @@
import logging import logging
import re import re
from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import GroupID, get_domain_from_id from synapse.types import GroupID, get_domain_from_id
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -43,7 +41,7 @@ class AppServiceTransaction(object):
Args: Args:
as_api(ApplicationServiceApi): The API to use to send. as_api(ApplicationServiceApi): The API to use to send.
Returns: Returns:
A Deferred which resolves to True if the transaction was sent. An Awaitable which resolves to True if the transaction was sent.
""" """
return as_api.push_bulk( return as_api.push_bulk(
service=self.service, events=self.events, txn_id=self.id service=self.service, events=self.events, txn_id=self.id
@ -172,8 +170,7 @@ class ApplicationService(object):
return regex_obj["exclusive"] return regex_obj["exclusive"]
return False return False
@defer.inlineCallbacks async def _matches_user(self, event, store):
def _matches_user(self, event, store):
if not event: if not event:
return False return False
@ -188,12 +185,12 @@ class ApplicationService(object):
if not store: if not store:
return False return False
does_match = yield self._matches_user_in_member_list(event.room_id, store) does_match = await self._matches_user_in_member_list(event.room_id, store)
return does_match return does_match
@cachedInlineCallbacks(num_args=1, cache_context=True) @cached(num_args=1, cache_context=True)
def _matches_user_in_member_list(self, room_id, store, cache_context): async def _matches_user_in_member_list(self, room_id, store, cache_context):
member_list = yield store.get_users_in_room( member_list = await store.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate room_id, on_invalidate=cache_context.invalidate
) )
@ -208,35 +205,33 @@ class ApplicationService(object):
return self.is_interested_in_room(event.room_id) return self.is_interested_in_room(event.room_id)
return False return False
@defer.inlineCallbacks async def _matches_aliases(self, event, store):
def _matches_aliases(self, event, store):
if not store or not event: if not store or not event:
return False return False
alias_list = yield store.get_aliases_for_room(event.room_id) alias_list = await store.get_aliases_for_room(event.room_id)
for alias in alias_list: for alias in alias_list:
if self.is_interested_in_alias(alias): if self.is_interested_in_alias(alias):
return True return True
return False return False
@defer.inlineCallbacks async def is_interested(self, event, store=None) -> bool:
def is_interested(self, event, store=None):
"""Check if this service is interested in this event. """Check if this service is interested in this event.
Args: Args:
event(Event): The event to check. event(Event): The event to check.
store(DataStore) store(DataStore)
Returns: Returns:
bool: True if this service would like to know about this event. True if this service would like to know about this event.
""" """
# Do cheap checks first # Do cheap checks first
if self._matches_room_id(event): if self._matches_room_id(event):
return True return True
if (yield self._matches_aliases(event, store)): if await self._matches_aliases(event, store):
return True return True
if (yield self._matches_user(event, store)): if await self._matches_user(event, store):
return True return True
return False return False

View file

@ -93,13 +93,12 @@ class ApplicationServiceApi(SimpleHttpClient):
hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
) )
@defer.inlineCallbacks async def query_user(self, service, user_id):
def query_user(self, service, user_id):
if service.url is None: if service.url is None:
return False return False
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
try: try:
response = yield self.get_json(uri, {"access_token": service.hs_token}) response = await self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object if response is not None: # just an empty json object
return True return True
except CodeMessageException as e: except CodeMessageException as e:
@ -110,14 +109,12 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_user to %s threw exception %s", uri, ex) logger.warning("query_user to %s threw exception %s", uri, ex)
return False return False
@defer.inlineCallbacks async def query_alias(self, service, alias):
def query_alias(self, service, alias):
if service.url is None: if service.url is None:
return False return False
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
response = None
try: try:
response = yield self.get_json(uri, {"access_token": service.hs_token}) response = await self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object if response is not None: # just an empty json object
return True return True
except CodeMessageException as e: except CodeMessageException as e:
@ -128,8 +125,7 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_alias to %s threw exception %s", uri, ex) logger.warning("query_alias to %s threw exception %s", uri, ex)
return False return False
@defer.inlineCallbacks async def query_3pe(self, service, kind, protocol, fields):
def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER: if kind == ThirdPartyEntityKind.USER:
required_field = "userid" required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION: elif kind == ThirdPartyEntityKind.LOCATION:
@ -146,7 +142,7 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.parse.quote(protocol), urllib.parse.quote(protocol),
) )
try: try:
response = yield self.get_json(uri, fields) response = await self.get_json(uri, fields)
if not isinstance(response, list): if not isinstance(response, list):
logger.warning( logger.warning(
"query_3pe to %s returned an invalid response %r", uri, response "query_3pe to %s returned an invalid response %r", uri, response
@ -202,8 +198,7 @@ class ApplicationServiceApi(SimpleHttpClient):
key = (service.id, protocol) key = (service.id, protocol)
return self.protocol_meta_cache.wrap(key, _get) return self.protocol_meta_cache.wrap(key, _get)
@defer.inlineCallbacks async def push_bulk(self, service, events, txn_id=None):
def push_bulk(self, service, events, txn_id=None):
if service.url is None: if service.url is None:
return True return True
@ -218,7 +213,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id)) uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
try: try:
yield self.put_json( await self.put_json(
uri=uri, uri=uri,
json_body={"events": events}, json_body={"events": events},
args={"access_token": service.hs_token}, args={"access_token": service.hs_token},

View file

@ -50,8 +50,6 @@ components.
""" """
import logging import logging
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState from synapse.appservice import ApplicationServiceState
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
@ -73,12 +71,11 @@ class ApplicationServiceScheduler(object):
self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
@defer.inlineCallbacks async def start(self):
def start(self):
logger.info("Starting appservice scheduler") logger.info("Starting appservice scheduler")
# check for any DOWN ASes and start recoverers for them. # check for any DOWN ASes and start recoverers for them.
services = yield self.store.get_appservices_by_state( services = await self.store.get_appservices_by_state(
ApplicationServiceState.DOWN ApplicationServiceState.DOWN
) )
@ -117,8 +114,7 @@ class _ServiceQueuer(object):
"as-sender-%s" % (service.id,), self._send_request, service "as-sender-%s" % (service.id,), self._send_request, service
) )
@defer.inlineCallbacks async def _send_request(self, service):
def _send_request(self, service):
# sanity-check: we shouldn't get here if this service already has a sender # sanity-check: we shouldn't get here if this service already has a sender
# running. # running.
assert service.id not in self.requests_in_flight assert service.id not in self.requests_in_flight
@ -130,7 +126,7 @@ class _ServiceQueuer(object):
if not events: if not events:
return return
try: try:
yield self.txn_ctrl.send(service, events) await self.txn_ctrl.send(service, events)
except Exception: except Exception:
logger.exception("AS request failed") logger.exception("AS request failed")
finally: finally:
@ -162,36 +158,33 @@ class _TransactionController(object):
# for UTs # for UTs
self.RECOVERER_CLASS = _Recoverer self.RECOVERER_CLASS = _Recoverer
@defer.inlineCallbacks async def send(self, service, events):
def send(self, service, events):
try: try:
txn = yield self.store.create_appservice_txn(service=service, events=events) txn = await self.store.create_appservice_txn(service=service, events=events)
service_is_up = yield self._is_service_up(service) service_is_up = await self._is_service_up(service)
if service_is_up: if service_is_up:
sent = yield txn.send(self.as_api) sent = await txn.send(self.as_api)
if sent: if sent:
yield txn.complete(self.store) await txn.complete(self.store)
else: else:
run_in_background(self._on_txn_fail, service) run_in_background(self._on_txn_fail, service)
except Exception: except Exception:
logger.exception("Error creating appservice transaction") logger.exception("Error creating appservice transaction")
run_in_background(self._on_txn_fail, service) run_in_background(self._on_txn_fail, service)
@defer.inlineCallbacks async def on_recovered(self, recoverer):
def on_recovered(self, recoverer):
logger.info( logger.info(
"Successfully recovered application service AS ID %s", recoverer.service.id "Successfully recovered application service AS ID %s", recoverer.service.id
) )
self.recoverers.pop(recoverer.service.id) self.recoverers.pop(recoverer.service.id)
logger.info("Remaining active recoverers: %s", len(self.recoverers)) logger.info("Remaining active recoverers: %s", len(self.recoverers))
yield self.store.set_appservice_state( await self.store.set_appservice_state(
recoverer.service, ApplicationServiceState.UP recoverer.service, ApplicationServiceState.UP
) )
@defer.inlineCallbacks async def _on_txn_fail(self, service):
def _on_txn_fail(self, service):
try: try:
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) await self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
self.start_recoverer(service) self.start_recoverer(service)
except Exception: except Exception:
logger.exception("Error starting AS recoverer") logger.exception("Error starting AS recoverer")
@ -211,9 +204,8 @@ class _TransactionController(object):
recoverer.recover() recoverer.recover()
logger.info("Now %i active recoverers", len(self.recoverers)) logger.info("Now %i active recoverers", len(self.recoverers))
@defer.inlineCallbacks async def _is_service_up(self, service):
def _is_service_up(self, service): state = await self.store.get_appservice_state(service)
state = yield self.store.get_appservice_state(service)
return state == ApplicationServiceState.UP or state is None return state == ApplicationServiceState.UP or state is None
@ -254,25 +246,24 @@ class _Recoverer(object):
self.backoff_counter += 1 self.backoff_counter += 1
self.recover() self.recover()
@defer.inlineCallbacks async def retry(self):
def retry(self):
logger.info("Starting retries on %s", self.service.id) logger.info("Starting retries on %s", self.service.id)
try: try:
while True: while True:
txn = yield self.store.get_oldest_unsent_txn(self.service) txn = await self.store.get_oldest_unsent_txn(self.service)
if not txn: if not txn:
# nothing left: we're done! # nothing left: we're done!
self.callback(self) await self.callback(self)
return return
logger.info( logger.info(
"Retrying transaction %s for AS ID %s", txn.id, txn.service.id "Retrying transaction %s for AS ID %s", txn.id, txn.service.id
) )
sent = yield txn.send(self.as_api) sent = await txn.send(self.as_api)
if not sent: if not sent:
break break
yield txn.complete(self.store) await txn.complete(self.store)
# reset the backoff counter and then process the next transaction # reset the backoff counter and then process the next transaction
self.backoff_counter = 1 self.backoff_counter = 1

View file

@ -27,7 +27,6 @@ from synapse.metrics import (
event_processing_loop_room_count, event_processing_loop_room_count,
) )
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import log_failure
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -100,10 +99,11 @@ class ApplicationServicesHandler(object):
if not self.started_scheduler: if not self.started_scheduler:
def start_scheduler(): async def start_scheduler():
return self.scheduler.start().addErrback( try:
log_failure, "Application Services Failure" return self.scheduler.start()
) except Exception:
logger.error("Application Services Failure")
run_as_background_process("as_scheduler", start_scheduler) run_as_background_process("as_scheduler", start_scheduler)
self.started_scheduler = True self.started_scheduler = True

View file

@ -50,13 +50,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_regex_user_id_prefix_match(self): def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue((yield self.service.is_interested(self.event))) self.assertTrue(
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(self): def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.assertFalse((yield self.service.is_interested(self.event))) self.assertFalse(
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_room_member_is_checked(self): def test_regex_room_member_is_checked(self):
@ -64,7 +68,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member" self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org" self.event.state_key = "@irc_foobar:matrix.org"
self.assertTrue((yield self.service.is_interested(self.event))) self.assertTrue(
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_room_id_match(self): def test_regex_room_id_match(self):
@ -72,7 +78,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("!some_prefix.*some_suffix:matrix.org") _regex("!some_prefix.*some_suffix:matrix.org")
) )
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
self.assertTrue((yield self.service.is_interested(self.event))) self.assertTrue(
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_room_id_no_match(self): def test_regex_room_id_no_match(self):
@ -80,19 +88,26 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("!some_prefix.*some_suffix:matrix.org") _regex("!some_prefix.*some_suffix:matrix.org")
) )
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
self.assertFalse((yield self.service.is_interested(self.event))) self.assertFalse(
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_alias_match(self): def test_regex_alias_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.store.get_aliases_for_room.return_value = [ self.store.get_aliases_for_room.return_value = defer.succeed(
"#irc_foobar:matrix.org", ["#irc_foobar:matrix.org", "#athing:matrix.org"]
"#athing:matrix.org", )
] self.store.get_users_in_room.return_value = defer.succeed([])
self.store.get_users_in_room.return_value = [] self.assertTrue(
self.assertTrue((yield self.service.is_interested(self.event, self.store))) (
yield defer.ensureDeferred(
self.service.is_interested(self.event, self.store)
)
)
)
def test_non_exclusive_alias(self): def test_non_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
@ -135,12 +150,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.store.get_aliases_for_room.return_value = [ self.store.get_aliases_for_room.return_value = defer.succeed(
"#xmpp_foobar:matrix.org", ["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
"#athing:matrix.org", )
] self.store.get_users_in_room.return_value = defer.succeed([])
self.store.get_users_in_room.return_value = [] self.assertFalse(
self.assertFalse((yield self.service.is_interested(self.event, self.store))) (
yield defer.ensureDeferred(
self.service.is_interested(self.event, self.store)
)
)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_multiple_matches(self): def test_regex_multiple_matches(self):
@ -149,9 +169,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
) )
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"] self.store.get_aliases_for_room.return_value = defer.succeed(
self.store.get_users_in_room.return_value = [] ["#irc_barfoo:matrix.org"]
self.assertTrue((yield self.service.is_interested(self.event, self.store))) )
self.store.get_users_in_room.return_value = defer.succeed([])
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested(self.event, self.store)
)
)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_interested_in_self(self): def test_interested_in_self(self):
@ -161,19 +189,24 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.event.type = "m.room.member" self.event.type = "m.room.member"
self.event.content = {"membership": "invite"} self.event.content = {"membership": "invite"}
self.event.state_key = self.service.sender self.event.state_key = self.service.sender
self.assertTrue((yield self.service.is_interested(self.event))) self.assertTrue(
(yield defer.ensureDeferred(self.service.is_interested(self.event)))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_member_list_match(self): def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.store.get_users_in_room.return_value = [ # Note that @irc_fo:here is the AS user.
"@alice:here", self.store.get_users_in_room.return_value = defer.succeed(
"@irc_fo:here", # AS user ["@alice:here", "@irc_fo:here", "@bob:here"]
"@bob:here", )
] self.store.get_aliases_for_room.return_value = defer.succeed([])
self.store.get_aliases_for_room.return_value = []
self.event.sender = "@xmpp_foobar:matrix.org" self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue( self.assertTrue(
(yield self.service.is_interested(event=self.event, store=self.store)) (
yield defer.ensureDeferred(
self.service.is_interested(event=self.event, store=self.store)
)
)
) )

View file

@ -25,6 +25,7 @@ from synapse.appservice.scheduler import (
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
from ..utils import MockClock from ..utils import MockClock
@ -52,11 +53,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.store.get_appservice_state = Mock( self.store.get_appservice_state = Mock(
return_value=defer.succeed(ApplicationServiceState.UP) return_value=defer.succeed(ApplicationServiceState.UP)
) )
txn.send = Mock(return_value=defer.succeed(True)) txn.send = Mock(return_value=make_awaitable(True))
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
# actual call # actual call
self.txnctrl.send(service, events) self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with( self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events # txn made and saved service=service, events=events # txn made and saved
@ -77,7 +78,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
# actual call # actual call
self.txnctrl.send(service, events) self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with( self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events # txn made and saved service=service, events=events # txn made and saved
@ -98,11 +99,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
return_value=defer.succeed(ApplicationServiceState.UP) return_value=defer.succeed(ApplicationServiceState.UP)
) )
self.store.set_appservice_state = Mock(return_value=defer.succeed(True)) self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
txn.send = Mock(return_value=defer.succeed(False)) # fails to send txn.send = Mock(return_value=make_awaitable(False)) # fails to send
self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
# actual call # actual call
self.txnctrl.send(service, events) self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with( self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events service=service, events=events
@ -144,7 +145,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover() self.recoverer.recover()
# shouldn't have called anything prior to waiting for exp backoff # shouldn't have called anything prior to waiting for exp backoff
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
txn.send = Mock(return_value=True) txn.send = Mock(return_value=make_awaitable(True))
txn.complete.return_value = make_awaitable(None)
# wait for exp backoff # wait for exp backoff
self.clock.advance_time(2) self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count) self.assertEquals(1, txn.send.call_count)
@ -169,7 +171,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover() self.recoverer.recover()
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
txn.send = Mock(return_value=False) txn.send = Mock(return_value=make_awaitable(False))
txn.complete.return_value = make_awaitable(None)
self.clock.advance_time(2) self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count) self.assertEquals(1, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count) self.assertEquals(0, txn.complete.call_count)
@ -182,7 +185,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.assertEquals(3, txn.send.call_count) self.assertEquals(3, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count) self.assertEquals(0, txn.complete.call_count)
self.assertEquals(0, self.callback.call_count) self.assertEquals(0, self.callback.call_count)
txn.send = Mock(return_value=True) # successfully send the txn txn.send = Mock(return_value=make_awaitable(True)) # successfully send the txn
pop_txn = True # returns the txn the first time, then no more. pop_txn = True # returns the txn the first time, then no more.
self.clock.advance_time(16) self.clock.advance_time(16)
self.assertEquals(1, txn.send.call_count) # new mock reset call count self.assertEquals(1, txn.send.call_count) # new mock reset call count

View file

@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from tests.test_utils import make_awaitable
from tests.utils import MockClock from tests.utils import MockClock
from .. import unittest from .. import unittest
@ -117,7 +118,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice_alias(is_interested_in_alias=False), self._mkservice_alias(is_interested_in_alias=False),
] ]
self.mock_as_api.query_alias.return_value = defer.succeed(True) self.mock_as_api.query_alias.return_value = make_awaitable(True)
self.mock_store.get_app_services.return_value = services self.mock_store.get_app_services.return_value = services
self.mock_store.get_association_from_room_alias.return_value = defer.succeed( self.mock_store.get_association_from_room_alias.return_value = defer.succeed(
Mock(room_id=room_id, servers=servers) Mock(room_id=room_id, servers=servers)
@ -135,7 +136,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def _mkservice(self, is_interested): def _mkservice(self, is_interested):
service = Mock() service = Mock()
service.is_interested.return_value = defer.succeed(is_interested) service.is_interested.return_value = make_awaitable(is_interested)
service.token = "mock_service_token" service.token = "mock_service_token"
service.url = "mock_service_url" service.url = "mock_service_url"
return service return service