Add '/event_auth/' federation api

This commit is contained in:
Erik Johnston 2014-11-07 15:35:53 +00:00
parent d2fb2b8095
commit 02c3b1c9e2
4 changed files with 55 additions and 7 deletions

View file

@ -426,6 +426,11 @@ class ReplicationLayer(object):
"auth_chain": [p.get_dict() for p in res_pdus["auth_chain"]],
}))
@defer.inlineCallbacks
def on_event_auth(self, origin, context, event_id):
auth_pdus = yield self.handler.on_event_auth(event_id)
defer.returnValue((200, [a.get_dict() for a in auth_pdus]))
@defer.inlineCallbacks
def make_join(self, destination, context, user_id):
pdu_dict = yield self.transport_layer.make_join(

View file

@ -256,6 +256,21 @@ class TransportLayer(object):
defer.returnValue(json.loads(content))
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, context, event_id):
path = PREFIX + "/event_auth/%s/%s" % (
context,
event_id,
)
response = yield self.client.get_json(
destination=destination,
path=path,
)
defer.returnValue(response)
@defer.inlineCallbacks
def _authenticate_request(self, request):
json_request = {
@ -426,6 +441,17 @@ class TransportLayer(object):
)
)
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, event_id:
handler.on_event_auth(
origin, context, event_id,
)
)
)
self.server.register_path(
"PUT",
re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),

View file

@ -224,6 +224,11 @@ class FederationHandler(BaseHandler):
defer.returnValue(self.pdu_codec.event_from_pdu(pdu))
@defer.inlineCallbacks
def on_event_auth(self, event_id):
auth = yield self.store.get_auth_chain(event_id)
defer.returnValue([self.pdu_codec.pdu_from_event(e) for e in auth])
@log_function
@defer.inlineCallbacks
def do_invite_join(self, target_host, room_id, joinee, content, snapshot):

View file

@ -32,6 +32,24 @@ class EventFederationStore(SQLBaseStore):
)
def _get_auth_chain_txn(self, txn, event_id):
results = self._get_auth_chain_ids_txn(txn, event_id)
sql = "SELECT * FROM events WHERE event_id = ?"
rows = []
for ev_id in results:
c = txn.execute(sql, (ev_id,))
rows.extend(self.cursor_to_dict(c))
return self._parse_events_txn(txn, rows)
def get_auth_chain_ids(self, event_id):
return self.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
event_id
)
def _get_auth_chain_ids_txn(self, txn, event_id):
results = set()
base_sql = (
@ -48,13 +66,7 @@ class EventFederationStore(SQLBaseStore):
front = [r[0] for r in txn.fetchall()]
results.update(front)
sql = "SELECT * FROM events WHERE event_id = ?"
rows = []
for ev_id in results:
c = txn.execute(sql, (ev_id,))
rows.extend(self.cursor_to_dict(c))
return self._parse_events_txn(txn, rows)
return list(results)
def get_oldest_events_in_room(self, room_id):
return self.runInteraction(