forked from MirrorHub/synapse
Merge branch 'release-v0.6.0' of github.com:matrix-org/synapse into erikj-perf
This commit is contained in:
commit
41ce544abe
5 changed files with 80 additions and 45 deletions
|
@ -256,31 +256,35 @@ class ReplicationLayer(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_state_for_context(self, destination, context, event_id=None):
|
||||
def get_state_for_context(self, destination, context, event_id):
|
||||
"""Requests all of the `current` state PDUs for a given context from
|
||||
a remote home server.
|
||||
|
||||
Args:
|
||||
destination (str): The remote homeserver to query for the state.
|
||||
context (str): The context we're interested in.
|
||||
event_id (str): The id of the event we want the state at.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a list of PDUs.
|
||||
"""
|
||||
|
||||
transaction_data = yield self.transport_layer.get_context_state(
|
||||
result = yield self.transport_layer.get_context_state(
|
||||
destination,
|
||||
context,
|
||||
event_id=event_id,
|
||||
)
|
||||
|
||||
transaction = Transaction(**transaction_data)
|
||||
pdus = [
|
||||
self.event_from_pdu_json(p, outlier=True)
|
||||
for p in transaction.pdus
|
||||
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
|
||||
]
|
||||
|
||||
defer.returnValue(pdus)
|
||||
auth_chain = [
|
||||
self.event_from_pdu_json(p, outlier=True)
|
||||
for p in result.get("auth_chain", [])
|
||||
]
|
||||
|
||||
defer.returnValue((pdus, auth_chain))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
|
@ -383,10 +387,16 @@ class ReplicationLayer(object):
|
|||
context,
|
||||
event_id,
|
||||
)
|
||||
auth_chain = yield self.store.get_auth_chain(
|
||||
[pdu.event_id for pdu in pdus]
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Specify an event")
|
||||
|
||||
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
|
||||
defer.returnValue((200, {
|
||||
"pdus": [pdu.get_pdu_json() for pdu in pdus],
|
||||
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
|
||||
}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
|
@ -573,6 +583,8 @@ class ReplicationLayer(object):
|
|||
|
||||
state = None
|
||||
|
||||
auth_chain = []
|
||||
|
||||
# We need to make sure we have all the auth events.
|
||||
# for e_id, _ in pdu.auth_events:
|
||||
# exists = yield self._get_persisted_pdu(
|
||||
|
@ -645,7 +657,7 @@ class ReplicationLayer(object):
|
|||
"_handle_new_pdu getting state for %s",
|
||||
pdu.room_id
|
||||
)
|
||||
state = yield self.get_state_for_context(
|
||||
state, auth_chain = yield self.get_state_for_context(
|
||||
origin, pdu.room_id, pdu.event_id,
|
||||
)
|
||||
|
||||
|
@ -655,6 +667,7 @@ class ReplicationLayer(object):
|
|||
pdu,
|
||||
backfilled=backfilled,
|
||||
state=state,
|
||||
auth_chain=auth_chain,
|
||||
)
|
||||
else:
|
||||
ret = None
|
||||
|
|
|
@ -95,7 +95,8 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def on_receive_pdu(self, origin, pdu, backfilled, state=None):
|
||||
def on_receive_pdu(self, origin, pdu, backfilled, state=None,
|
||||
auth_chain=None):
|
||||
""" Called by the ReplicationLayer when we have a new pdu. We need to
|
||||
do auth checks and put it through the StateHandler.
|
||||
"""
|
||||
|
@ -150,35 +151,35 @@ class FederationHandler(BaseHandler):
|
|||
if not is_in_room and not event.internal_metadata.outlier:
|
||||
logger.debug("Got event for room we're not in.")
|
||||
|
||||
replication_layer = self.replication_layer
|
||||
auth_chain = yield replication_layer.get_event_auth(
|
||||
origin,
|
||||
context=event.room_id,
|
||||
event_id=event.event_id,
|
||||
)
|
||||
replication = self.replication_layer
|
||||
|
||||
if not state:
|
||||
state, auth_chain = yield replication.get_state_for_context(
|
||||
origin, context=event.room_id, event_id=event.event_id,
|
||||
)
|
||||
|
||||
if not auth_chain:
|
||||
auth_chain = yield replication.get_event_auth(
|
||||
origin,
|
||||
context=event.room_id,
|
||||
event_id=event.event_id,
|
||||
)
|
||||
|
||||
for e in auth_chain:
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(e, fetch_missing=False)
|
||||
yield self._handle_new_event(e, fetch_auth_from=origin)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to handle auth event %s",
|
||||
e.event_id,
|
||||
)
|
||||
|
||||
if not state:
|
||||
state = yield replication_layer.get_state_for_context(
|
||||
origin,
|
||||
context=event.room_id,
|
||||
event_id=event.event_id,
|
||||
)
|
||||
# FIXME: Get auth chain for these state events
|
||||
|
||||
current_state = state
|
||||
|
||||
if state:
|
||||
for e in state:
|
||||
logging.info("A :) %r", e)
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(e)
|
||||
|
@ -392,7 +393,7 @@ class FederationHandler(BaseHandler):
|
|||
for e in auth_chain:
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(e, fetch_missing=False)
|
||||
yield self._handle_new_event(e)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to handle auth event %s",
|
||||
|
@ -404,8 +405,7 @@ class FederationHandler(BaseHandler):
|
|||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(
|
||||
e,
|
||||
fetch_missing=True
|
||||
e, fetch_auth_from=target_host
|
||||
)
|
||||
except:
|
||||
logger.exception(
|
||||
|
@ -682,7 +682,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_new_event(self, event, state=None, backfilled=False,
|
||||
current_state=None, fetch_missing=True):
|
||||
current_state=None, fetch_auth_from=None):
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: Before annotate: %s, sigs: %s",
|
||||
|
@ -703,11 +703,20 @@ class FederationHandler(BaseHandler):
|
|||
known_ids = set(
|
||||
[s.event_id for s in context.auth_events.values()]
|
||||
)
|
||||
|
||||
for e_id, _ in event.auth_events:
|
||||
if e_id not in known_ids:
|
||||
e = yield self.store.get_event(
|
||||
e_id, allow_none=True,
|
||||
)
|
||||
e = yield self.store.get_event(e_id, allow_none=True)
|
||||
|
||||
if not e and fetch_auth_from is not None:
|
||||
# Grab the auth_chain over federation if we are missing
|
||||
# auth events.
|
||||
auth_chain = yield self.replication_layer.get_event_auth(
|
||||
fetch_auth_from, event.event_id, event.room_id
|
||||
)
|
||||
for auth_event in auth_chain:
|
||||
yield self._handle_new_event(auth_event)
|
||||
e = yield self.store.get_event(e_id, allow_none=True)
|
||||
|
||||
if not e:
|
||||
# TODO: Do some conflict res to make sure that we're
|
||||
|
|
|
@ -115,10 +115,10 @@ class Signal(object):
|
|||
failure.value,
|
||||
failure.getTracebackObject()))
|
||||
if not self.suppress_failures:
|
||||
raise failure
|
||||
failure.raiseException()
|
||||
deferreds.append(d.addErrback(eb))
|
||||
|
||||
result = yield defer.DeferredList(
|
||||
deferreds, fireOnOneErrback=not self.suppress_failures
|
||||
)
|
||||
defer.returnValue(result)
|
||||
results = []
|
||||
for deferred in deferreds:
|
||||
result = yield deferred
|
||||
results.append(result)
|
||||
defer.returnValue(results)
|
||||
|
|
|
@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase):
|
|||
"get_received_txn_response",
|
||||
"set_received_txn_response",
|
||||
"get_destination_retry_timings",
|
||||
"get_auth_chain",
|
||||
])
|
||||
self.mock_persistence.get_received_txn_response.return_value = (
|
||||
defer.succeed(None)
|
||||
|
@ -59,6 +60,7 @@ class FederationTestCase(unittest.TestCase):
|
|||
self.mock_persistence.get_destination_retry_timings.return_value = (
|
||||
defer.succeed(DestinationsTable.EntryType("", 0, 0))
|
||||
)
|
||||
self.mock_persistence.get_auth_chain.return_value = []
|
||||
self.mock_config = Mock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
self.clock = MockClock()
|
||||
|
|
|
@ -13,12 +13,13 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from tests import unittest
|
||||
from . import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from mock import Mock, patch
|
||||
|
||||
from synapse.util.distributor import Distributor
|
||||
from synapse.util.async import run_on_reactor
|
||||
|
||||
|
||||
class DistributorTestCase(unittest.TestCase):
|
||||
|
@ -26,6 +27,7 @@ class DistributorTestCase(unittest.TestCase):
|
|||
def setUp(self):
|
||||
self.dist = Distributor()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_dispatch(self):
|
||||
self.dist.declare("alert")
|
||||
|
||||
|
@ -33,10 +35,11 @@ class DistributorTestCase(unittest.TestCase):
|
|||
self.dist.observe("alert", observer)
|
||||
|
||||
d = self.dist.fire("alert", 1, 2, 3)
|
||||
|
||||
yield d
|
||||
self.assertTrue(d.called)
|
||||
observer.assert_called_with(1, 2, 3)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_dispatch_deferred(self):
|
||||
self.dist.declare("whine")
|
||||
|
||||
|
@ -50,8 +53,10 @@ class DistributorTestCase(unittest.TestCase):
|
|||
self.assertFalse(d_outer.called)
|
||||
|
||||
d_inner.callback(None)
|
||||
yield d_outer
|
||||
self.assertTrue(d_outer.called)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_catch(self):
|
||||
self.dist.declare("alarm")
|
||||
|
||||
|
@ -65,6 +70,7 @@ class DistributorTestCase(unittest.TestCase):
|
|||
spec=["warning"]
|
||||
) as mock_logger:
|
||||
d = self.dist.fire("alarm", "Go")
|
||||
yield d
|
||||
self.assertTrue(d.called)
|
||||
|
||||
observers[0].assert_called_once("Go")
|
||||
|
@ -81,23 +87,28 @@ class DistributorTestCase(unittest.TestCase):
|
|||
|
||||
self.dist.declare("whail")
|
||||
|
||||
observer = Mock()
|
||||
observer.return_value = defer.fail(
|
||||
Exception("Oopsie")
|
||||
)
|
||||
class MyException(Exception):
|
||||
pass
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def observer():
|
||||
yield run_on_reactor()
|
||||
raise MyException("Oopsie")
|
||||
|
||||
self.dist.observe("whail", observer)
|
||||
|
||||
d = self.dist.fire("whail")
|
||||
|
||||
yield self.assertFailure(d, Exception)
|
||||
yield self.assertFailure(d, MyException)
|
||||
self.dist.suppress_failures = True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_prereg(self):
|
||||
observer = Mock()
|
||||
self.dist.observe("flare", observer)
|
||||
|
||||
self.dist.declare("flare")
|
||||
self.dist.fire("flare", 4, 5)
|
||||
yield self.dist.fire("flare", 4, 5)
|
||||
|
||||
observer.assert_called_with(4, 5)
|
||||
|
||||
|
|
Loading…
Reference in a new issue