mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-07 14:14:46 +01:00
Merge pull request #208 from matrix-org/markjh/end-to-end-key-federation
Federation for end-to-end key requests.
This commit is contained in:
commit
8899df13bf
5 changed files with 233 additions and 30 deletions
|
@ -134,6 +134,36 @@ class FederationClient(FederationBase):
|
||||||
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
|
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def query_client_keys(self, destination, content):
|
||||||
|
"""Query device keys for a device hosted on a remote server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination (str): Domain name of the remote homeserver
|
||||||
|
content (dict): The query content.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a Deferred which will eventually yield a JSON object from the
|
||||||
|
response
|
||||||
|
"""
|
||||||
|
sent_queries_counter.inc("client_device_keys")
|
||||||
|
return self.transport_layer.query_client_keys(destination, content)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def claim_client_keys(self, destination, content):
|
||||||
|
"""Claims one-time keys for a device hosted on a remote server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination (str): Domain name of the remote homeserver
|
||||||
|
content (dict): The query content.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a Deferred which will eventually yield a JSON object from the
|
||||||
|
response
|
||||||
|
"""
|
||||||
|
sent_queries_counter.inc("client_one_time_keys")
|
||||||
|
return self.transport_layer.claim_client_keys(destination, content)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def backfill(self, dest, context, limit, extremities):
|
def backfill(self, dest, context, limit, extremities):
|
||||||
|
|
|
@ -27,6 +27,7 @@ from synapse.api.errors import FederationError, SynapseError
|
||||||
|
|
||||||
from synapse.crypto.event_signing import compute_event_signature
|
from synapse.crypto.event_signing import compute_event_signature
|
||||||
|
|
||||||
|
import simplejson as json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
@ -312,6 +313,48 @@ class FederationServer(FederationBase):
|
||||||
(200, send_content)
|
(200, send_content)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
@log_function
|
||||||
|
def on_query_client_keys(self, origin, content):
|
||||||
|
query = []
|
||||||
|
for user_id, device_ids in content.get("device_keys", {}).items():
|
||||||
|
if not device_ids:
|
||||||
|
query.append((user_id, None))
|
||||||
|
else:
|
||||||
|
for device_id in device_ids:
|
||||||
|
query.append((user_id, device_id))
|
||||||
|
|
||||||
|
results = yield self.store.get_e2e_device_keys(query)
|
||||||
|
|
||||||
|
json_result = {}
|
||||||
|
for user_id, device_keys in results.items():
|
||||||
|
for device_id, json_bytes in device_keys.items():
|
||||||
|
json_result.setdefault(user_id, {})[device_id] = json.loads(
|
||||||
|
json_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({"device_keys": json_result})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
@log_function
|
||||||
|
def on_claim_client_keys(self, origin, content):
|
||||||
|
query = []
|
||||||
|
for user_id, device_keys in content.get("one_time_keys", {}).items():
|
||||||
|
for device_id, algorithm in device_keys.items():
|
||||||
|
query.append((user_id, device_id, algorithm))
|
||||||
|
|
||||||
|
results = yield self.store.claim_e2e_one_time_keys(query)
|
||||||
|
|
||||||
|
json_result = {}
|
||||||
|
for user_id, device_keys in results.items():
|
||||||
|
for device_id, keys in device_keys.items():
|
||||||
|
for key_id, json_bytes in keys.items():
|
||||||
|
json_result.setdefault(user_id, {})[device_id] = {
|
||||||
|
key_id: json.loads(json_bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer.returnValue({"one_time_keys": json_result})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def on_get_missing_events(self, origin, room_id, earliest_events,
|
def on_get_missing_events(self, origin, room_id, earliest_events,
|
||||||
|
|
|
@ -222,6 +222,76 @@ class TransportLayerClient(object):
|
||||||
|
|
||||||
defer.returnValue(content)
|
defer.returnValue(content)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
@log_function
|
||||||
|
def query_client_keys(self, destination, query_content):
|
||||||
|
"""Query the device keys for a list of user ids hosted on a remote
|
||||||
|
server.
|
||||||
|
|
||||||
|
Request:
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"<user_id>": ["<device_id>"]
|
||||||
|
} }
|
||||||
|
|
||||||
|
Response:
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"<user_id>": {
|
||||||
|
"<device_id>": {...}
|
||||||
|
} } }
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination(str): The server to query.
|
||||||
|
query_content(dict): The user ids to query.
|
||||||
|
Returns:
|
||||||
|
A dict containg the device keys.
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/user/keys/query"
|
||||||
|
|
||||||
|
content = yield self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
data=query_content,
|
||||||
|
)
|
||||||
|
defer.returnValue(content)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
@log_function
|
||||||
|
def claim_client_keys(self, destination, query_content):
|
||||||
|
"""Claim one-time keys for a list of devices hosted on a remote server.
|
||||||
|
|
||||||
|
Request:
|
||||||
|
{
|
||||||
|
"one_time_keys": {
|
||||||
|
"<user_id>": {
|
||||||
|
"<device_id>": "<algorithm>"
|
||||||
|
} } }
|
||||||
|
|
||||||
|
Response:
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"<user_id>": {
|
||||||
|
"<device_id>": {
|
||||||
|
"<algorithm>:<key_id>": "<key_base64>"
|
||||||
|
} } } }
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination(str): The server to query.
|
||||||
|
query_content(dict): The user ids to query.
|
||||||
|
Returns:
|
||||||
|
A dict containg the one-time keys.
|
||||||
|
"""
|
||||||
|
|
||||||
|
path = PREFIX + "/user/keys/claim"
|
||||||
|
|
||||||
|
content = yield self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
data=query_content,
|
||||||
|
)
|
||||||
|
defer.returnValue(content)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def get_missing_events(self, destination, room_id, earliest_events,
|
def get_missing_events(self, destination, room_id, earliest_events,
|
||||||
|
|
|
@ -325,6 +325,24 @@ class FederationInviteServlet(BaseFederationServlet):
|
||||||
defer.returnValue((200, content))
|
defer.returnValue((200, content))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationClientKeysQueryServlet(BaseFederationServlet):
|
||||||
|
PATH = "/user/keys/query"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query):
|
||||||
|
response = yield self.handler.on_query_client_keys(origin, content)
|
||||||
|
defer.returnValue((200, response))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationClientKeysClaimServlet(BaseFederationServlet):
|
||||||
|
PATH = "/user/keys/claim"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query):
|
||||||
|
response = yield self.handler.on_claim_client_keys(origin, content)
|
||||||
|
defer.returnValue((200, response))
|
||||||
|
|
||||||
|
|
||||||
class FederationQueryAuthServlet(BaseFederationServlet):
|
class FederationQueryAuthServlet(BaseFederationServlet):
|
||||||
PATH = "/query_auth/([^/]*)/([^/]*)"
|
PATH = "/query_auth/([^/]*)/([^/]*)"
|
||||||
|
|
||||||
|
@ -373,4 +391,6 @@ SERVLET_CLASSES = (
|
||||||
FederationQueryAuthServlet,
|
FederationQueryAuthServlet,
|
||||||
FederationGetMissingEventsServlet,
|
FederationGetMissingEventsServlet,
|
||||||
FederationEventAuthServlet,
|
FederationEventAuthServlet,
|
||||||
|
FederationClientKeysQueryServlet,
|
||||||
|
FederationClientKeysClaimServlet,
|
||||||
)
|
)
|
||||||
|
|
|
@ -17,6 +17,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.http.servlet import RestServlet
|
from synapse.http.servlet import RestServlet
|
||||||
|
from synapse.types import UserID
|
||||||
from syutil.jsonutil import encode_canonical_json
|
from syutil.jsonutil import encode_canonical_json
|
||||||
|
|
||||||
from ._base import client_v2_pattern
|
from ._base import client_v2_pattern
|
||||||
|
@ -164,45 +165,63 @@ class KeyQueryServlet(RestServlet):
|
||||||
super(KeyQueryServlet, self).__init__()
|
super(KeyQueryServlet, self).__init__()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
self.federation = hs.get_replication_layer()
|
||||||
|
self.is_mine = hs.is_mine
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, user_id, device_id):
|
def on_POST(self, request, user_id, device_id):
|
||||||
logger.debug("onPOST")
|
|
||||||
yield self.auth.get_user_by_req(request)
|
yield self.auth.get_user_by_req(request)
|
||||||
try:
|
try:
|
||||||
body = json.loads(request.content.read())
|
body = json.loads(request.content.read())
|
||||||
except:
|
except:
|
||||||
raise SynapseError(400, "Invalid key JSON")
|
raise SynapseError(400, "Invalid key JSON")
|
||||||
query = []
|
result = yield self.handle_request(body)
|
||||||
for user_id, device_ids in body.get("device_keys", {}).items():
|
defer.returnValue(result)
|
||||||
if not device_ids:
|
|
||||||
query.append((user_id, None))
|
|
||||||
else:
|
|
||||||
for device_id in device_ids:
|
|
||||||
query.append((user_id, device_id))
|
|
||||||
results = yield self.store.get_e2e_device_keys(query)
|
|
||||||
defer.returnValue(self.json_result(request, results))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id, device_id):
|
def on_GET(self, request, user_id, device_id):
|
||||||
auth_user, client_info = yield self.auth.get_user_by_req(request)
|
auth_user, client_info = yield self.auth.get_user_by_req(request)
|
||||||
auth_user_id = auth_user.to_string()
|
auth_user_id = auth_user.to_string()
|
||||||
if not user_id:
|
user_id = user_id if user_id else auth_user_id
|
||||||
user_id = auth_user_id
|
device_ids = [device_id] if device_id else []
|
||||||
if not device_id:
|
result = yield self.handle_request(
|
||||||
device_id = None
|
{"device_keys": {user_id: device_ids}}
|
||||||
# Returns a map of user_id->device_id->json_bytes.
|
)
|
||||||
results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
|
defer.returnValue(result)
|
||||||
defer.returnValue(self.json_result(request, results))
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def handle_request(self, body):
|
||||||
|
local_query = []
|
||||||
|
remote_queries = {}
|
||||||
|
for user_id, device_ids in body.get("device_keys", {}).items():
|
||||||
|
user = UserID.from_string(user_id)
|
||||||
|
if self.is_mine(user):
|
||||||
|
if not device_ids:
|
||||||
|
local_query.append((user_id, None))
|
||||||
|
else:
|
||||||
|
for device_id in device_ids:
|
||||||
|
local_query.append((user_id, device_id))
|
||||||
|
else:
|
||||||
|
remote_queries.setdefault(user.domain, {})[user_id] = list(
|
||||||
|
device_ids
|
||||||
|
)
|
||||||
|
results = yield self.store.get_e2e_device_keys(local_query)
|
||||||
|
|
||||||
def json_result(self, request, results):
|
|
||||||
json_result = {}
|
json_result = {}
|
||||||
for user_id, device_keys in results.items():
|
for user_id, device_keys in results.items():
|
||||||
for device_id, json_bytes in device_keys.items():
|
for device_id, json_bytes in device_keys.items():
|
||||||
json_result.setdefault(user_id, {})[device_id] = json.loads(
|
json_result.setdefault(user_id, {})[device_id] = json.loads(
|
||||||
json_bytes
|
json_bytes
|
||||||
)
|
)
|
||||||
return (200, {"device_keys": json_result})
|
|
||||||
|
for destination, device_keys in remote_queries.items():
|
||||||
|
remote_result = yield self.federation.query_client_keys(
|
||||||
|
destination, {"device_keys": device_keys}
|
||||||
|
)
|
||||||
|
for user_id, keys in remote_result["device_keys"].items():
|
||||||
|
if user_id in device_keys:
|
||||||
|
json_result[user_id] = keys
|
||||||
|
defer.returnValue((200, {"device_keys": json_result}))
|
||||||
|
|
||||||
|
|
||||||
class OneTimeKeyServlet(RestServlet):
|
class OneTimeKeyServlet(RestServlet):
|
||||||
|
@ -236,14 +255,16 @@ class OneTimeKeyServlet(RestServlet):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self.federation = hs.get_replication_layer()
|
||||||
|
self.is_mine = hs.is_mine
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id, device_id, algorithm):
|
def on_GET(self, request, user_id, device_id, algorithm):
|
||||||
yield self.auth.get_user_by_req(request)
|
yield self.auth.get_user_by_req(request)
|
||||||
results = yield self.store.claim_e2e_one_time_keys(
|
result = yield self.handle_request(
|
||||||
[(user_id, device_id, algorithm)]
|
{"one_time_keys": {user_id: {device_id: algorithm}}}
|
||||||
)
|
)
|
||||||
defer.returnValue(self.json_result(request, results))
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, user_id, device_id, algorithm):
|
def on_POST(self, request, user_id, device_id, algorithm):
|
||||||
|
@ -252,14 +273,24 @@ class OneTimeKeyServlet(RestServlet):
|
||||||
body = json.loads(request.content.read())
|
body = json.loads(request.content.read())
|
||||||
except:
|
except:
|
||||||
raise SynapseError(400, "Invalid key JSON")
|
raise SynapseError(400, "Invalid key JSON")
|
||||||
query = []
|
result = yield self.handle_request(body)
|
||||||
for user_id, device_keys in body.get("one_time_keys", {}).items():
|
defer.returnValue(result)
|
||||||
for device_id, algorithm in device_keys.items():
|
|
||||||
query.append((user_id, device_id, algorithm))
|
@defer.inlineCallbacks
|
||||||
results = yield self.store.claim_e2e_one_time_keys(query)
|
def handle_request(self, body):
|
||||||
defer.returnValue(self.json_result(request, results))
|
local_query = []
|
||||||
|
remote_queries = {}
|
||||||
|
for user_id, device_keys in body.get("one_time_keys", {}).items():
|
||||||
|
user = UserID.from_string(user_id)
|
||||||
|
if self.is_mine(user):
|
||||||
|
for device_id, algorithm in device_keys.items():
|
||||||
|
local_query.append((user_id, device_id, algorithm))
|
||||||
|
else:
|
||||||
|
remote_queries.setdefault(user.domain, {})[user_id] = (
|
||||||
|
device_keys
|
||||||
|
)
|
||||||
|
results = yield self.store.claim_e2e_one_time_keys(local_query)
|
||||||
|
|
||||||
def json_result(self, request, results):
|
|
||||||
json_result = {}
|
json_result = {}
|
||||||
for user_id, device_keys in results.items():
|
for user_id, device_keys in results.items():
|
||||||
for device_id, keys in device_keys.items():
|
for device_id, keys in device_keys.items():
|
||||||
|
@ -267,7 +298,16 @@ class OneTimeKeyServlet(RestServlet):
|
||||||
json_result.setdefault(user_id, {})[device_id] = {
|
json_result.setdefault(user_id, {})[device_id] = {
|
||||||
key_id: json.loads(json_bytes)
|
key_id: json.loads(json_bytes)
|
||||||
}
|
}
|
||||||
return (200, {"one_time_keys": json_result})
|
|
||||||
|
for destination, device_keys in remote_queries.items():
|
||||||
|
remote_result = yield self.federation.claim_client_keys(
|
||||||
|
destination, {"one_time_keys": device_keys}
|
||||||
|
)
|
||||||
|
for user_id, keys in remote_result["one_time_keys"].items():
|
||||||
|
if user_id in device_keys:
|
||||||
|
json_result[user_id] = keys
|
||||||
|
|
||||||
|
defer.returnValue((200, {"one_time_keys": json_result}))
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
|
|
Loading…
Reference in a new issue