diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 925eb5376..692c2d8a7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -285,7 +285,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def on_event_auth(self, event_id): - auth = yield self.store.get_auth_chain(event_id) + auth = yield self.store.get_auth_chain([event_id]) for event in auth: event.signatures.update( @@ -494,7 +494,10 @@ class FederationHandler(BaseHandler): yield self.replication_layer.send_pdu(new_pdu) - auth_chain = yield self.store.get_auth_chain(event.event_id) + state_ids = [e.event_id for e in event.state_events.values()] + auth_chain = yield self.store.get_auth_chain(set( + [event.event_id] + state_ids + )) defer.returnValue({ "state": event.state_events.values(), diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 6c559f8f6..0ff9a23ee 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -32,15 +32,15 @@ class EventFederationStore(SQLBaseStore): and backfilling from another server respectively. """ - def get_auth_chain(self, event_id): + def get_auth_chain(self, event_ids): return self.runInteraction( "get_auth_chain", self._get_auth_chain_txn, - event_id + event_ids ) - def _get_auth_chain_txn(self, txn, event_id): - results = self._get_auth_chain_ids_txn(txn, event_id) + def _get_auth_chain_txn(self, txn, event_ids): + results = self._get_auth_chain_ids_txn(txn, event_ids) sql = "SELECT * FROM events WHERE event_id = ?" rows = [] @@ -50,21 +50,21 @@ class EventFederationStore(SQLBaseStore): return self._parse_events_txn(txn, rows) - def get_auth_chain_ids(self, event_id): + def get_auth_chain_ids(self, event_ids): return self.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, - event_id + event_ids ) - def _get_auth_chain_ids_txn(self, txn, event_id): + def _get_auth_chain_ids_txn(self, txn, event_ids): results = set() base_sql = ( "SELECT auth_id FROM event_auth WHERE %s" ) - front = set([event_id]) + front = set(event_ids) while front: sql = base_sql % ( " OR ".join(["event_id=?"] * len(front)),