From 62c010283d543db0956066b42eb735b57c000a82 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 23 Jul 2015 16:03:38 +0100 Subject: [PATCH 1/3] Add federation support for end-to-end key requests --- synapse/federation/federation_client.py | 34 ++++++++ synapse/federation/federation_server.py | 37 +++++++++ synapse/federation/transport/client.py | 70 +++++++++++++++++ synapse/federation/transport/server.py | 20 +++++ synapse/rest/client/v2_alpha/keys.py | 100 +++++++++++++++++------- 5 files changed, 231 insertions(+), 30 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 7736d14fb..21a86a4c6 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -134,6 +134,40 @@ class FederationClient(FederationBase): destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail ) + @log_function + def query_client_keys(self, destination, content, retry_on_dns_fail=True): + """Query device keys for a device hosted on a remote server. + + Args: + destination (str): Domain name of the remote homeserver + content (dict): The query content. + + Returns: + a Deferred which will eventually yield a JSON object from the + response + """ + sent_queries_counter.inc("client_device_keys") + return self.transport_layer.query_client_keys( + destination, content, retry_on_dns_fail=retry_on_dns_fail + ) + + @log_function + def claim_client_keys(self, destination, content, retry_on_dns_fail=True): + """Claims one-time keys for a device hosted on a remote server. + + Args: + destination (str): Domain name of the remote homeserver + content (dict): The query content. + + Returns: + a Deferred which will eventually yield a JSON object from the + response + """ + sent_queries_counter.inc("client_one_time_keys") + return self.transport_layer.claim_client_keys( + destination, content, retry_on_dns_fail=retry_on_dns_fail + ) + @defer.inlineCallbacks @log_function def backfill(self, dest, context, limit, extremities): diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index cd79e23f4..c32908ac2 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -27,6 +27,7 @@ from synapse.api.errors import FederationError, SynapseError from synapse.crypto.event_signing import compute_event_signature +import simplejson as json import logging @@ -312,6 +313,42 @@ class FederationServer(FederationBase): (200, send_content) ) + @defer.inlineCallbacks + @log_function + def on_query_client_keys(self, origin, content): + query = [] + for user_id, device_ids in content.get("device_keys", {}).items(): + if not device_ids: + query.append((user_id, None)) + else: + for device_id in device_ids: + query.append((user_id, device_id)) + results = yield self.store.get_e2e_device_keys(query) + json_result = {} + for user_id, device_keys in results.items(): + for device_id, json_bytes in device_keys.items(): + json_result.setdefault(user_id, {})[device_id] = json.loads( + json_bytes + ) + defer.returnValue({"device_keys": json_result}) + + @defer.inlineCallbacks + @log_function + def on_claim_client_keys(self, origin, content): + query = [] + for user_id, device_keys in content.get("one_time_keys", {}).items(): + for device_id, algorithm in device_keys.items(): + query.append((user_id, device_id, algorithm)) + results = yield self.store.claim_e2e_one_time_keys(query) + json_result = {} + for user_id, device_keys in results.items(): + for device_id, keys in device_keys.items(): + for key_id, json_bytes in keys.items(): + json_result.setdefault(user_id, {})[device_id] = { + key_id: json.loads(json_bytes) + } + defer.returnValue({"one_time_keys": json_result}) + @defer.inlineCallbacks @log_function def on_get_missing_events(self, origin, room_id, earliest_events, diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 610a4c316..df5083dd2 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -222,6 +222,76 @@ class TransportLayerClient(object): defer.returnValue(content) + @defer.inlineCallbacks + @log_function + def query_client_keys(self, destination, query_content): + """Query the device keys for a list of user ids hosted on a remote + server. + + Request: + { + "device_keys": { + "": [""] + } } + + Response: + { + "device_keys": { + "": { + "": {...} + } } } + + Args: + destination(str): The server to query. + query_content(dict): The user ids to query. + Returns: + A dict containg the device keys. + """ + path = PREFIX + "/client_keys/query" + + content = yield self.client.post_json( + destination=destination, + path=path, + data=query_content, + ) + defer.returnValue(content) + + @defer.inlineCallbacks + @log_function + def claim_client_keys(self, destination, query_content): + """Claim one-time keys for a list of devices hosted on a remote server. + + Request: + { + "one_time_keys": { + "": { + "": "" + } } } + + Response: + { + "device_keys": { + "": { + "": { + ":": "" + } } } } + + Args: + destination(str): The server to query. + query_content(dict): The user ids to query. + Returns: + A dict containg the one-time keys. + """ + + path = PREFIX + "/client_keys/claim" + + content = yield self.client.post_json( + destination=destination, + path=path, + data=query_content, + ) + defer.returnValue(content) + @defer.inlineCallbacks @log_function def get_missing_events(self, destination, room_id, earliest_events, diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index bad93c6b2..fb59383ec 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -325,6 +325,24 @@ class FederationInviteServlet(BaseFederationServlet): defer.returnValue((200, content)) +class FederationClientKeysQueryServlet(BaseFederationServlet): + PATH = "/client_keys/query" + + @defer.inlineCallbacks + def on_POST(self, origin, content): + response = yield self.handler.on_client_key_query(origin, content) + defer.returnValue((200, response)) + + +class FederationClientKeysClaimServlet(BaseFederationServlet): + PATH = "/client_keys/claim" + + @defer.inlineCallbacks + def on_POST(self, origin, content): + response = yield self.handler.on_client_key_claim(origin, content) + defer.returnValue((200, response)) + + class FederationQueryAuthServlet(BaseFederationServlet): PATH = "/query_auth/([^/]*)/([^/]*)" @@ -373,4 +391,6 @@ SERVLET_CLASSES = ( FederationQueryAuthServlet, FederationGetMissingEventsServlet, FederationEventAuthServlet, + FederationClientKeysQueryServlet, + FederationClientKeysClaimServlet, ) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 5f3a6207b..739a08ada 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet +from synapse.types import UserID from syutil.jsonutil import encode_canonical_json from ._base import client_v2_pattern @@ -164,45 +165,63 @@ class KeyQueryServlet(RestServlet): super(KeyQueryServlet, self).__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() + self.federation = hs.get_replication_layer() + self.is_mine = hs.is_mine @defer.inlineCallbacks def on_POST(self, request, user_id, device_id): - logger.debug("onPOST") yield self.auth.get_user_by_req(request) try: body = json.loads(request.content.read()) except: raise SynapseError(400, "Invalid key JSON") - query = [] - for user_id, device_ids in body.get("device_keys", {}).items(): - if not device_ids: - query.append((user_id, None)) - else: - for device_id in device_ids: - query.append((user_id, device_id)) - results = yield self.store.get_e2e_device_keys(query) - defer.returnValue(self.json_result(request, results)) + result = yield self.handle_request(body) + defer.returnValue(result) @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): auth_user, client_info = yield self.auth.get_user_by_req(request) auth_user_id = auth_user.to_string() - if not user_id: - user_id = auth_user_id - if not device_id: - device_id = None - # Returns a map of user_id->device_id->json_bytes. - results = yield self.store.get_e2e_device_keys([(user_id, device_id)]) - defer.returnValue(self.json_result(request, results)) + user_id = user_id if user_id else auth_user_id + device_ids = [device_id] if device_id else [] + result = yield self.handle_request( + {"device_keys": {user_id: device_ids}} + ) + defer.returnValue(result) + + @defer.inlineCallbacks + def handle_request(self, body): + local_query = [] + remote_queries = {} + for user_id, device_ids in body.get("device_keys", {}).items(): + user = UserID.from_string(user_id) + if self.is_mine(user): + if not device_ids: + local_query.append((user_id, None)) + else: + for device_id in device_ids: + local_query.append((user_id, device_id)) + else: + remote_queries.set_default(user.domain, {})[user_id] = list( + device_ids + ) + results = yield self.store.get_e2e_device_keys(local_query) - def json_result(self, request, results): json_result = {} for user_id, device_keys in results.items(): for device_id, json_bytes in device_keys.items(): json_result.setdefault(user_id, {})[device_id] = json.loads( json_bytes ) - return (200, {"device_keys": json_result}) + + for destination, device_keys in remote_queries.items(): + remote_result = yield self.federation.query_client_keys( + destination, {"device_keys": device_keys} + ) + for user_id, keys in remote_result.items(): + if user_id in device_keys: + json_result[user_id] = keys + defer.returnValue((200, {"device_keys": json_result})) class OneTimeKeyServlet(RestServlet): @@ -236,14 +255,16 @@ class OneTimeKeyServlet(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() self.clock = hs.get_clock() + self.federation = hs.get_replication_layer() + self.is_mine = hs.is_mine @defer.inlineCallbacks def on_GET(self, request, user_id, device_id, algorithm): yield self.auth.get_user_by_req(request) - results = yield self.store.claim_e2e_one_time_keys( - [(user_id, device_id, algorithm)] + result = yield self.handle_request( + {"one_time_keys": {user_id: {device_id: algorithm}}} ) - defer.returnValue(self.json_result(request, results)) + defer.returnValue(result) @defer.inlineCallbacks def on_POST(self, request, user_id, device_id, algorithm): @@ -252,14 +273,24 @@ class OneTimeKeyServlet(RestServlet): body = json.loads(request.content.read()) except: raise SynapseError(400, "Invalid key JSON") - query = [] - for user_id, device_keys in body.get("one_time_keys", {}).items(): - for device_id, algorithm in device_keys.items(): - query.append((user_id, device_id, algorithm)) - results = yield self.store.claim_e2e_one_time_keys(query) - defer.returnValue(self.json_result(request, results)) + result = yield self.handle_request(body) + defer.returnValue(result) + + @defer.inlineCallbacks + def handle_request(self, body): + local_query = [] + remote_queries = {} + for user_id, device_keys in body.get("one_time_keys", {}).items(): + user = UserID.from_string(user_id) + if self.is_mine(user): + for device_id, algorithm in device_keys.items(): + local_query.append((user_id, device_id, algorithm)) + else: + remote_queries.set_default(user.domain, {})[user_id] = ( + device_keys + ) + results = yield self.store.claim_e2e_one_time_keys(local_query) - def json_result(self, request, results): json_result = {} for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): @@ -267,7 +298,16 @@ class OneTimeKeyServlet(RestServlet): json_result.setdefault(user_id, {})[device_id] = { key_id: json.loads(json_bytes) } - return (200, {"one_time_keys": json_result}) + + for destination, device_keys in remote_queries.items(): + remote_result = yield self.federation.query_client_keys( + destination, {"one_time_keys": device_keys} + ) + for user_id, keys in remote_result.items(): + if user_id in device_keys: + json_result[user_id] = keys + + defer.returnValue((200, {"one_time_keys": json_result})) def register_servlets(hs, http_server): From 2da3b1e60bf7e9ae1d6714abcff0a0c224cadf28 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 24 Jul 2015 18:26:46 +0100 Subject: [PATCH 2/3] Get the end-to-end key federation working --- synapse/federation/federation_client.py | 12 ++++-------- synapse/federation/transport/client.py | 4 ++-- synapse/federation/transport/server.py | 12 ++++++------ synapse/rest/client/v2_alpha/keys.py | 10 +++++----- 4 files changed, 17 insertions(+), 21 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 21a86a4c6..44e4d0755 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -135,7 +135,7 @@ class FederationClient(FederationBase): ) @log_function - def query_client_keys(self, destination, content, retry_on_dns_fail=True): + def query_client_keys(self, destination, content): """Query device keys for a device hosted on a remote server. Args: @@ -147,12 +147,10 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_device_keys") - return self.transport_layer.query_client_keys( - destination, content, retry_on_dns_fail=retry_on_dns_fail - ) + return self.transport_layer.query_client_keys(destination, content) @log_function - def claim_client_keys(self, destination, content, retry_on_dns_fail=True): + def claim_client_keys(self, destination, content): """Claims one-time keys for a device hosted on a remote server. Args: @@ -164,9 +162,7 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_one_time_keys") - return self.transport_layer.claim_client_keys( - destination, content, retry_on_dns_fail=retry_on_dns_fail - ) + return self.transport_layer.claim_client_keys(destination, content) @defer.inlineCallbacks @log_function diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index df5083dd2..ced703364 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -247,7 +247,7 @@ class TransportLayerClient(object): Returns: A dict containg the device keys. """ - path = PREFIX + "/client_keys/query" + path = PREFIX + "/user/keys/query" content = yield self.client.post_json( destination=destination, @@ -283,7 +283,7 @@ class TransportLayerClient(object): A dict containg the one-time keys. """ - path = PREFIX + "/client_keys/claim" + path = PREFIX + "/user/keys/claim" content = yield self.client.post_json( destination=destination, diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index fb59383ec..36f250e1a 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -326,20 +326,20 @@ class FederationInviteServlet(BaseFederationServlet): class FederationClientKeysQueryServlet(BaseFederationServlet): - PATH = "/client_keys/query" + PATH = "/user/keys/query" @defer.inlineCallbacks - def on_POST(self, origin, content): - response = yield self.handler.on_client_key_query(origin, content) + def on_POST(self, origin, content, query): + response = yield self.handler.on_query_client_keys(origin, content) defer.returnValue((200, response)) class FederationClientKeysClaimServlet(BaseFederationServlet): - PATH = "/client_keys/claim" + PATH = "/user/keys/claim" @defer.inlineCallbacks - def on_POST(self, origin, content): - response = yield self.handler.on_client_key_claim(origin, content) + def on_POST(self, origin, content, query): + response = yield self.handler.on_claim_client_keys(origin, content) defer.returnValue((200, response)) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 739a08ada..718928eed 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -202,7 +202,7 @@ class KeyQueryServlet(RestServlet): for device_id in device_ids: local_query.append((user_id, device_id)) else: - remote_queries.set_default(user.domain, {})[user_id] = list( + remote_queries.setdefault(user.domain, {})[user_id] = list( device_ids ) results = yield self.store.get_e2e_device_keys(local_query) @@ -218,7 +218,7 @@ class KeyQueryServlet(RestServlet): remote_result = yield self.federation.query_client_keys( destination, {"device_keys": device_keys} ) - for user_id, keys in remote_result.items(): + for user_id, keys in remote_result["device_keys"].items(): if user_id in device_keys: json_result[user_id] = keys defer.returnValue((200, {"device_keys": json_result})) @@ -286,7 +286,7 @@ class OneTimeKeyServlet(RestServlet): for device_id, algorithm in device_keys.items(): local_query.append((user_id, device_id, algorithm)) else: - remote_queries.set_default(user.domain, {})[user_id] = ( + remote_queries.setdefault(user.domain, {})[user_id] = ( device_keys ) results = yield self.store.claim_e2e_one_time_keys(local_query) @@ -300,10 +300,10 @@ class OneTimeKeyServlet(RestServlet): } for destination, device_keys in remote_queries.items(): - remote_result = yield self.federation.query_client_keys( + remote_result = yield self.federation.claim_client_keys( destination, {"one_time_keys": device_keys} ) - for user_id, keys in remote_result.items(): + for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: json_result[user_id] = keys From 0cceb2ac92cde0a4289adfc6e9000c7b1c54bdae Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 13 Aug 2015 17:27:46 +0100 Subject: [PATCH 3/3] Add a few strategic new lines to break up the on_query_client_keys and on_claim_client_keys methods in federation_server.py --- synapse/federation/federation_server.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index c32908ac2..725c6f3fa 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -323,13 +323,16 @@ class FederationServer(FederationBase): else: for device_id in device_ids: query.append((user_id, device_id)) + results = yield self.store.get_e2e_device_keys(query) + json_result = {} for user_id, device_keys in results.items(): for device_id, json_bytes in device_keys.items(): json_result.setdefault(user_id, {})[device_id] = json.loads( json_bytes ) + defer.returnValue({"device_keys": json_result}) @defer.inlineCallbacks @@ -339,7 +342,9 @@ class FederationServer(FederationBase): for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): query.append((user_id, device_id, algorithm)) + results = yield self.store.claim_e2e_one_time_keys(query) + json_result = {} for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): @@ -347,6 +352,7 @@ class FederationServer(FederationBase): json_result.setdefault(user_id, {})[device_id] = { key_id: json.loads(json_bytes) } + defer.returnValue({"one_time_keys": json_result}) @defer.inlineCallbacks