mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 19:43:53 +01:00
Implement minimum_valid_until_ts in the remote key resource
This commit is contained in:
parent
55e1bc8920
commit
46d200a3a1
2 changed files with 56 additions and 4 deletions
|
@ -289,6 +289,7 @@ class Keyring(object):
|
||||||
key_base64 = key_data["key"]
|
key_base64 = key_data["key"]
|
||||||
key_bytes = decode_base64(key_base64)
|
key_bytes = decode_base64(key_base64)
|
||||||
verify_key = decode_verify_key_bytes(key_id, key_bytes)
|
verify_key = decode_verify_key_bytes(key_id, key_bytes)
|
||||||
|
verify_key.time_added = time_now_ms
|
||||||
verify_keys[key_id] = verify_key
|
verify_keys[key_id] = verify_key
|
||||||
|
|
||||||
old_verify_keys = {}
|
old_verify_keys = {}
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.http.server import request_handler, respond_with_json_bytes
|
from synapse.http.server import request_handler, respond_with_json_bytes
|
||||||
|
from synapse.http.servlet import parse_integer
|
||||||
from synapse.api.errors import SynapseError, Codes
|
from synapse.api.errors import SynapseError, Codes
|
||||||
|
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
|
@ -44,7 +45,13 @@ class RemoteKey(Resource):
|
||||||
POST /_matrix/v2/query HTTP/1.1
|
POST /_matrix/v2/query HTTP/1.1
|
||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
{
|
{
|
||||||
"server_keys": { "remote.server.example.com": ["a.key.id"] }
|
"server_keys": {
|
||||||
|
"remote.server.example.com": {
|
||||||
|
"a.key.id": {
|
||||||
|
"minimum_valid_until_ts": 1234567890123
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Response:
|
Response:
|
||||||
|
@ -96,10 +103,16 @@ class RemoteKey(Resource):
|
||||||
def async_render_GET(self, request):
|
def async_render_GET(self, request):
|
||||||
if len(request.postpath) == 1:
|
if len(request.postpath) == 1:
|
||||||
server, = request.postpath
|
server, = request.postpath
|
||||||
query = {server: [None]}
|
query = {server: {}}
|
||||||
elif len(request.postpath) == 2:
|
elif len(request.postpath) == 2:
|
||||||
server, key_id = request.postpath
|
server, key_id = request.postpath
|
||||||
query = {server: [key_id]}
|
minimum_valid_until_ts = parse_integer(
|
||||||
|
request, "minimum_valid_until_ts"
|
||||||
|
)
|
||||||
|
arguments = {}
|
||||||
|
if minimum_valid_until_ts is not None:
|
||||||
|
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
|
||||||
|
query = {server: {key_id: arguments}}
|
||||||
else:
|
else:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
404, "Not found %r" % request.postpath, Codes.NOT_FOUND
|
404, "Not found %r" % request.postpath, Codes.NOT_FOUND
|
||||||
|
@ -128,8 +141,11 @@ class RemoteKey(Resource):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def query_keys(self, request, query, query_remote_on_cache_miss=False):
|
def query_keys(self, request, query, query_remote_on_cache_miss=False):
|
||||||
|
logger.info("Handling query for keys %r", query)
|
||||||
store_queries = []
|
store_queries = []
|
||||||
for server_name, key_ids in query.items():
|
for server_name, key_ids in query.items():
|
||||||
|
if not key_ids:
|
||||||
|
key_ids = (None,)
|
||||||
for key_id in key_ids:
|
for key_id in key_ids:
|
||||||
store_queries.append((server_name, key_id, None))
|
store_queries.append((server_name, key_id, None))
|
||||||
|
|
||||||
|
@ -152,9 +168,44 @@ class RemoteKey(Resource):
|
||||||
if key_id is not None:
|
if key_id is not None:
|
||||||
ts_added_ms, most_recent_result = max(results)
|
ts_added_ms, most_recent_result = max(results)
|
||||||
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
|
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
|
||||||
if (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
|
req_key = query.get(server_name, {}).get(key_id, {})
|
||||||
|
req_valid_until = req_key.get("minimum_valid_until_ts")
|
||||||
|
miss = False
|
||||||
|
if req_valid_until is not None:
|
||||||
|
if ts_valid_until_ms < req_valid_until:
|
||||||
|
logger.debug(
|
||||||
|
"Cached response for %r/%r is older than requested"
|
||||||
|
": valid_until (%r) < minimum_valid_until (%r)",
|
||||||
|
server_name, key_id,
|
||||||
|
ts_valid_until_ms, req_valid_until
|
||||||
|
)
|
||||||
|
miss = True
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"Cached response for %r/%r is newer than requested"
|
||||||
|
": valid_until (%r) >= minimum_valid_until (%r)",
|
||||||
|
server_name, key_id,
|
||||||
|
ts_valid_until_ms, req_valid_until
|
||||||
|
)
|
||||||
|
elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
|
||||||
|
logger.debug(
|
||||||
|
"Cached response for %r/%r is too old"
|
||||||
|
": (added (%r) + valid_until (%r)) / 2 < now (%r)",
|
||||||
|
server_name, key_id,
|
||||||
|
ts_added_ms, ts_valid_until_ms, time_now_ms
|
||||||
|
)
|
||||||
# We more than half way through the lifetime of the
|
# We more than half way through the lifetime of the
|
||||||
# response. We should fetch a fresh copy.
|
# response. We should fetch a fresh copy.
|
||||||
|
miss = True
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"Cached response for %r/%r is still valid"
|
||||||
|
": (added (%r) + valid_until (%r)) / 2 < now (%r)",
|
||||||
|
server_name, key_id,
|
||||||
|
ts_added_ms, ts_valid_until_ms, time_now_ms
|
||||||
|
)
|
||||||
|
|
||||||
|
if miss:
|
||||||
cache_misses.setdefault(server_name, set()).add(key_id)
|
cache_misses.setdefault(server_name, set()).add(key_id)
|
||||||
json_results.add(bytes(most_recent_result["key_json"]))
|
json_results.add(bytes(most_recent_result["key_json"]))
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in a new issue