diff --git a/changelog.d/5236.misc b/changelog.d/5236.misc new file mode 100644 index 000000000..cb4417a9f --- /dev/null +++ b/changelog.d/5236.misc @@ -0,0 +1 @@ +Simplify Keyring.process_v2_response. \ No newline at end of file diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index ea910a2f2..9d629b223 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -489,7 +489,7 @@ class Keyring(object): ) processed_response = yield self.process_v2_response( - perspective_name, response + perspective_name, response, time_added_ms=time_now_ms ) server_name = response["server_name"] @@ -541,6 +541,7 @@ class Keyring(object): from_server=server_name, requested_ids=[requested_key_id], response_json=response, + time_added_ms=time_now_ms, ) yield self.store.store_server_verify_keys( server_name, @@ -552,7 +553,9 @@ class Keyring(object): defer.returnValue({server_name: keys}) @defer.inlineCallbacks - def process_v2_response(self, from_server, response_json, requested_ids=[]): + def process_v2_response( + self, from_server, response_json, time_added_ms, requested_ids=[] + ): """Parse a 'Server Keys' structure from the result of a /key request This is used to parse either the entirety of the response from @@ -573,6 +576,8 @@ class Keyring(object): response_json (dict): the json-decoded Server Keys response object + time_added_ms (int): the timestamp to record in server_keys_json + requested_ids (iterable[str]): a list of the key IDs that were requested. We will store the json for these key ids as well as any that are actually in the response @@ -581,8 +586,9 @@ class Keyring(object): Deferred[dict[str, nacl.signing.VerifyKey]]: map from key_id to key object """ - time_now_ms = self.clock.time_msec() - response_keys = {} + + # start by extracting the keys from the response, since they may be required + # to validate the signature on the response. verify_keys = {} for key_id, key_data in response_json["verify_keys"].items(): if is_signing_algorithm_supported(key_id): @@ -591,23 +597,27 @@ class Keyring(object): verify_key = decode_verify_key_bytes(key_id, key_bytes) verify_keys[key_id] = verify_key - old_verify_keys = {} + # TODO: improve this signature checking + server_name = response_json["server_name"] + for key_id in response_json["signatures"].get(server_name, {}): + if key_id not in verify_keys: + raise KeyLookupError( + "Key response must include verification keys for all signatures" + ) + + verify_signed_json( + response_json, server_name, verify_keys[key_id] + ) + for key_id, key_data in response_json["old_verify_keys"].items(): if is_signing_algorithm_supported(key_id): key_base64 = key_data["key"] key_bytes = decode_base64(key_base64) verify_key = decode_verify_key_bytes(key_id, key_bytes) - old_verify_keys[key_id] = verify_key - - server_name = response_json["server_name"] - for key_id in response_json["signatures"].get(server_name, {}): - if key_id not in response_json["verify_keys"]: - raise KeyLookupError( - "Key response must include verification keys for all" " signatures" - ) - if key_id in verify_keys: - verify_signed_json(response_json, server_name, verify_keys[key_id]) + verify_keys[key_id] = verify_key + # re-sign the json with our own key, so that it is ready if we are asked to + # give it out as a notary server signed_key_json = sign_json( response_json, self.config.server_name, self.config.signing_key[0] ) @@ -615,12 +625,10 @@ class Keyring(object): signed_key_json_bytes = encode_canonical_json(signed_key_json) ts_valid_until_ms = signed_key_json[u"valid_until_ts"] + # for reasons I don't quite understand, we store this json for the key ids we + # requested, as well as those we got. updated_key_ids = set(requested_ids) updated_key_ids.update(verify_keys) - updated_key_ids.update(old_verify_keys) - - response_keys.update(verify_keys) - response_keys.update(old_verify_keys) yield logcontext.make_deferred_yieldable( defer.gatherResults( @@ -630,7 +638,7 @@ class Keyring(object): server_name=server_name, key_id=key_id, from_server=from_server, - ts_now_ms=time_now_ms, + ts_now_ms=time_added_ms, ts_expires_ms=ts_valid_until_ms, key_json_bytes=signed_key_json_bytes, ) @@ -640,7 +648,7 @@ class Keyring(object): ).addErrback(unwrapFirstError) ) - defer.returnValue(response_keys) + defer.returnValue(verify_keys) @defer.inlineCallbacks