From 8af781c5d792a3b30f3c073019e479d8d9a480b1 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 27 Apr 2023 17:58:02 -0700 Subject: [PATCH] modules/client/keys/query: Handle local users without loopbacking. --- modules/client/keys/query.cc | 138 ++++++++++++++++++++++++++++++++--- 1 file changed, 127 insertions(+), 11 deletions(-) diff --git a/modules/client/keys/query.cc b/modules/client/keys/query.cc index c51b3ae52..ff30ccc7b 100644 --- a/modules/client/keys/query.cc +++ b/modules/client/keys/query.cc @@ -21,20 +21,22 @@ namespace static void handle_cross_keys(const m::resource::request &, + const user_devices_map &, query_map &, failure_map &, json::stack::object &, - const string_view &, - const bool match_user = false); + const string_view &); static void handle_device_keys(const m::resource::request &, + const user_devices_map &, query_map &, failure_map &, json::stack::object &); static void handle_responses(const m::resource::request &, + const host_users_map &, query_map &, failure_map &, json::stack::object &); @@ -188,7 +190,7 @@ post__keys_query(client &client, out }; - handle_responses(request, queries, failures, top); + handle_responses(request, map, queries, failures, top); handle_failures(failures, top); return response; } @@ -244,7 +246,8 @@ send_requests(const host_users_map &hosts, { query_map ret; for(const auto &[remote, user_devices] : hosts) - send_request(remote, user_devices, failures, buffers, ret); + if(likely(!my_host(remote))) + send_request(remote, user_devices, failures, buffers, ret); return ret; } @@ -304,15 +307,28 @@ catch(const std::exception &e) void handle_responses(const m::resource::request &request, + const host_users_map &map, query_map &queries, failure_map &failures, json::stack::object &out) { + static const user_devices_map empty; + + const auto it + { + map.find(origin(m::my())) + }; + + const user_devices_map &self + { + it != end(map)? it->second: empty + }; + handle_errors(request, queries, failures); - handle_device_keys(request, queries, failures, out); - handle_cross_keys(request, queries, failures, out, "master_keys"); - handle_cross_keys(request, queries, failures, out, "self_signing_keys"); - handle_cross_keys(request, queries, failures, out, "user_signing_keys", true); + handle_device_keys(request, self, queries, failures, out); + handle_cross_keys(request, self, queries, failures, out, "master_keys"); + handle_cross_keys(request, self, queries, failures, out, "self_signing_keys"); + handle_cross_keys(request, self, queries, failures, out, "user_signing_keys"); } void @@ -335,6 +351,7 @@ handle_errors(const m::resource::request &request, void handle_device_keys(const m::resource::request &request, + const user_devices_map &self, query_map &queries, failure_map &failures, json::stack::object &out) @@ -344,7 +361,50 @@ handle_device_keys(const m::resource::request &request, out, "device_keys" }; - for(auto &[remote, query] : queries) try + // local handle + for(const auto &[user_id, device_ids] : self) + { + const m::user::keys keys + { + user_id + }; + + json::stack::object user_object + { + object, user_id + }; + + if(empty(json::array(device_ids))) + { + const m::user::devices devices + { + user_id + }; + + devices.for_each([&user_object, &keys] + (const auto &, const string_view &device_id) + { + json::stack::object device_object + { + user_object, device_id + }; + + keys.device(device_object, device_id); + }); + } + else for(const json::string device_id : json::array(device_ids)) + { + json::stack::object device_object + { + user_object, device_id + }; + + keys.device(device_object, device_id); + } + } + + // remote handle + for(const auto &[remote, query] : queries) try { const json::object response { @@ -389,19 +449,75 @@ handle_device_keys(const m::resource::request &request, } } +static std::tuple +translate_cross_type(const string_view &name) +{ + bool match_user; + string_view cross_type; + switch(match_user = false; hash(name)) + { + case "master_keys"_: + cross_type = "ircd.cross_signing.master"; + break; + + case "self_signing_keys"_: + cross_type = "ircd.cross_signing.self"; + break; + + case "user_signing_keys"_: + cross_type = "ircd.cross_signing.user"; + match_user = true; + break; + }; + + assert(cross_type); + return + { + cross_type, match_user + }; +} + void handle_cross_keys(const m::resource::request &request, + const user_devices_map &self, query_map &queries, failure_map &failures, json::stack::object &out_, - const string_view &name, - const bool match_user) + const string_view &name) { + const auto &[cross_type, match_user] + { + translate_cross_type(name) + }; + json::stack::object out { out_, name }; + // local handle + for(const auto &[user_id, device_ids] : self) + { + if(match_user && request.user_id != user_id) + continue; + + const m::user::keys keys + { + user_id + }; + + if(!keys.has_cross(cross_type)) + continue; + + json::stack::object user_object + { + out, user_id + }; + + keys.cross(user_object, cross_type); + } + + // remote handle for(auto &[remote, query] : queries) try { if(match_user && request.user_id.host() != remote)