0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2024-07-01 16:28:19 +02:00

ircd::net::dns Include query in callback arguments to prevent any stale captures.

This commit is contained in:
Jason Volk 2018-05-01 18:40:03 -07:00
parent 39ad36e3ed
commit fb53069c6f
6 changed files with 76 additions and 68 deletions

View file

@ -35,10 +35,10 @@ struct ircd::net::dns
struct resolver static *resolver; struct resolver static *resolver;
struct opts static const opts_default; struct opts static const opts_default;
using callback = std::function<void (std::exception_ptr, vector_view<const rfc1035::record *>)>; using callback = std::function<void (std::exception_ptr, const hostport &, vector_view<const rfc1035::record *>)>;
using callback_A_one = std::function<void (std::exception_ptr, const rfc1035::record::A &)>; using callback_A_one = std::function<void (std::exception_ptr, const hostport &, const rfc1035::record::A &)>;
using callback_SRV_one = std::function<void (std::exception_ptr, const rfc1035::record::SRV &)>; using callback_SRV_one = std::function<void (std::exception_ptr, const hostport &, const rfc1035::record::SRV &)>;
using callback_ipport_one = std::function<void (std::exception_ptr, const ipport &)>; using callback_ipport_one = std::function<void (std::exception_ptr, const hostport &, const ipport &)>;
// (internal) generate strings for rfc1035 questions or dns::cache keys. // (internal) generate strings for rfc1035 questions or dns::cache keys.
static string_view make_SRV_key(const mutable_buffer &out, const hostport &, const opts &); static string_view make_SRV_key(const mutable_buffer &out, const hostport &, const opts &);

View file

@ -52,7 +52,7 @@ struct ircd::net::dns::resolver
void send_query(const const_buffer &, tag &); void send_query(const const_buffer &, tag &);
void submit(const const_buffer &, tag &); void submit(const const_buffer &, tag &);
tag &set_tag(tag &&); template<class... A> tag &set_tag(A&&...);
const_buffer make_query(const mutable_buffer &buf, const tag &) const; const_buffer make_query(const mutable_buffer &buf, const tag &) const;
void operator()(const hostport &, const opts &, callback); void operator()(const hostport &, const opts &, callback);
@ -72,20 +72,23 @@ struct ircd::net::dns::resolver
struct ircd::net::dns::resolver::tag struct ircd::net::dns::resolver::tag
{ {
uint16_t id {0}; uint16_t id {0};
hostport hp; // note: invalid after query sent hostport hp;
dns::opts opts; // note: invalid after query sent dns::opts opts; // note: invalid after query sent
callback cb; callback cb;
steady_point last; steady_point last;
uint8_t tries {0}; uint8_t tries {0};
char hostbuf[256];
tag(const hostport &, const dns::opts &, callback); tag(const hostport &, const dns::opts &, callback &&);
}; };
inline inline
ircd::net::dns::resolver::tag::tag(const hostport &hp, ircd::net::dns::resolver::tag::tag(const hostport &hp,
const dns::opts &opts, const dns::opts &opts,
callback cb) callback &&cb)
:hp{hp} :hp{hp}
,opts{opts} ,opts{opts}
,cb{std::move(cb)} ,cb{std::move(cb)}
{} {
this->hp.host = { hostbuf, copy(hostbuf, hp.host) };
}

View file

@ -41,7 +41,7 @@ struct ircd::server::peer
template<class F> size_t accumulate_tags(F&&) const; template<class F> size_t accumulate_tags(F&&) const;
void handle_finished(); void handle_finished();
void handle_resolve(std::exception_ptr, const ipport &); void handle_resolve(std::exception_ptr, const hostport &, const ipport &);
void resolve(const hostport &); void resolve(const hostport &);
void disperse_uncommitted(link &); void disperse_uncommitted(link &);

View file

@ -547,7 +547,7 @@ ircd::net::open(socket &socket,
}}; }};
auto connector{[&socket, opts, complete(std::move(complete))] auto connector{[&socket, opts, complete(std::move(complete))]
(std::exception_ptr eptr, const ipport &ipport) (std::exception_ptr eptr, const hostport &hp, const ipport &ipport)
{ {
if(eptr) if(eptr)
return complete(std::move(eptr)); return complete(std::move(eptr));
@ -559,7 +559,7 @@ ircd::net::open(socket &socket,
if(!opts.ipport) if(!opts.ipport)
dns(opts.hostport, std::move(connector)); dns(opts.hostport, std::move(connector));
else else
connector({}, opts.ipport); connector({}, opts.hostport, opts.ipport);
} }
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -2361,21 +2361,21 @@ ircd::net::dns::cache::min_ttl
decltype(ircd::net::dns::prefetch_ipport) decltype(ircd::net::dns::prefetch_ipport)
ircd::net::dns::prefetch_ipport{[] ircd::net::dns::prefetch_ipport{[]
(std::exception_ptr, const auto &record) (std::exception_ptr, const auto &hostport, const auto &record)
{ {
// Do nothing; cache already updated if necessary // Do nothing; cache already updated if necessary
}}; }};
decltype(ircd::net::dns::prefetch_SRV) decltype(ircd::net::dns::prefetch_SRV)
ircd::net::dns::prefetch_SRV{[] ircd::net::dns::prefetch_SRV{[]
(std::exception_ptr, const auto &record) (std::exception_ptr, const auto &hostport, const auto &record)
{ {
// Do nothing; cache already updated if necessary // Do nothing; cache already updated if necessary
}}; }};
decltype(ircd::net::dns::prefetch_A) decltype(ircd::net::dns::prefetch_A)
ircd::net::dns::prefetch_A{[] ircd::net::dns::prefetch_A{[]
(std::exception_ptr, const auto &record) (std::exception_ptr, const auto &hostport, const auto &record)
{ {
// Do nothing; cache already updated if necessary // Do nothing; cache already updated if necessary
}}; }};
@ -2390,51 +2390,48 @@ ircd::net::dns::operator()(const hostport &hp,
{ {
//TODO: ip6 //TODO: ip6
auto calluser{[callback(std::move(callback))] auto calluser{[callback(std::move(callback))]
(std::exception_ptr eptr, const uint32_t &ip, const uint16_t &port) (std::exception_ptr eptr, const hostport &hp, const uint32_t &ip)
{ {
if(eptr) if(eptr)
return callback(std::move(eptr), {}); return callback(std::move(eptr), hp, {});
if(!ip) if(!ip)
return callback(std::make_exception_ptr(net::not_found{"Host has no A record"}), {}); return callback(std::make_exception_ptr(net::not_found{"Host has no A record"}), hp, {});
const ipport ipport{ip, port}; const ipport ipport{ip, port(hp)};
callback(std::move(eptr), ipport); callback(std::move(eptr), hp, ipport);
}}; }};
if(!hp.service) if(!hp.service)
return operator()(hp, opts, [hp, calluser(std::move(calluser))] return operator()(hp, opts, [calluser(std::move(calluser))]
(std::exception_ptr eptr, const rfc1035::record::A &record) (std::exception_ptr eptr, const hostport &hp, const rfc1035::record::A &record)
{ {
calluser(std::move(eptr), record.ip4, port(hp)); calluser(std::move(eptr), hp, record.ip4);
}); });
auto srv_opts{opts}; auto srv_opts{opts};
srv_opts.nxdomain_exceptions = false; srv_opts.nxdomain_exceptions = false;
operator()(hp, srv_opts, [this, hp(hp), opts(opts), calluser(std::move(calluser))] operator()(hp, srv_opts, [this, opts(opts), calluser(std::move(calluser))]
(std::exception_ptr eptr, const rfc1035::record::SRV &record) (std::exception_ptr eptr, hostport hp, const rfc1035::record::SRV &record)
mutable mutable
{ {
if(eptr) if(eptr)
return calluser(std::move(eptr), 0, 0); return calluser(std::move(eptr), hp, 0);
if(record.port != 0) if(record.port != 0)
hp.port = record.port; hp.port = record.port;
// The host reference in hp has become invalid by the time of this cb. hp.host = record.tgt?: unmake_SRV_key(hp.host);
assert(!record.tgt.empty());
hp.host = record.tgt;
// Have to kill the service name to not run another SRV query now, and // Have to kill the service name to not run another SRV query now.
// these are also invalid references too.
hp.service = {}; hp.service = {};
opts.srv = {}; opts.srv = {};
opts.proto = {}; opts.proto = {};
this->operator()(hp, opts, [hp, calluser(std::move(calluser))] this->operator()(hp, opts, [calluser(std::move(calluser))]
(std::exception_ptr eptr, const rfc1035::record::A &record) (std::exception_ptr eptr, const hostport &hp, const rfc1035::record::A &record)
{ {
calluser(std::move(eptr), record.ip4, port(hp)); calluser(std::move(eptr), hp, record.ip4);
}); });
}); });
} }
@ -2442,16 +2439,16 @@ ircd::net::dns::operator()(const hostport &hp,
/// Convenience callback with a single SRV record which was selected from /// Convenience callback with a single SRV record which was selected from
/// the vector with stochastic respect for weighting and priority. /// the vector with stochastic respect for weighting and priority.
void void
ircd::net::dns::operator()(const hostport &hostport, ircd::net::dns::operator()(const hostport &hp,
const opts &opts, const opts &opts,
callback_SRV_one callback) callback_SRV_one callback)
{ {
assert(bool(ircd::net::dns::resolver)); assert(bool(ircd::net::dns::resolver));
operator()(hostport, opts, [callback(std::move(callback))] operator()(hp, opts, [callback(std::move(callback))]
(std::exception_ptr eptr, const vector_view<const rfc1035::record *> rrs) (std::exception_ptr eptr, const hostport &hp, const vector_view<const rfc1035::record *> rrs)
{ {
if(eptr) if(eptr)
return callback(std::move(eptr), {}); return callback(std::move(eptr), hp, {});
//TODO: prng on weight / prio plz //TODO: prng on weight / prio plz
for(size_t i(0); i < rrs.size(); ++i) for(size_t i(0); i < rrs.size(); ++i)
@ -2461,26 +2458,26 @@ ircd::net::dns::operator()(const hostport &hostport,
continue; continue;
const auto &record(rr.as<const rfc1035::record::SRV>()); const auto &record(rr.as<const rfc1035::record::SRV>());
return callback(std::move(eptr), record); return callback(std::move(eptr), hp, record);
} }
return callback(std::move(eptr), {}); return callback(std::move(eptr), hp, {});
}); });
} }
/// Convenience callback with a single A record which was selected from /// Convenience callback with a single A record which was selected from
/// the vector randomly. /// the vector randomly.
void void
ircd::net::dns::operator()(const hostport &hostport, ircd::net::dns::operator()(const hostport &hp,
const opts &opts, const opts &opts,
callback_A_one callback) callback_A_one callback)
{ {
assert(bool(ircd::net::dns::resolver)); assert(bool(ircd::net::dns::resolver));
operator()(hostport, opts, [callback(std::move(callback))] operator()(hp, opts, [callback(std::move(callback))]
(std::exception_ptr eptr, const vector_view<const rfc1035::record *> rrs) (std::exception_ptr eptr, const hostport &hp, const vector_view<const rfc1035::record *> rrs)
{ {
if(eptr) if(eptr)
return callback(std::move(eptr), {}); return callback(std::move(eptr), hp, {});
//TODO: prng plz //TODO: prng plz
for(size_t i(0); i < rrs.size(); ++i) for(size_t i(0); i < rrs.size(); ++i)
@ -2490,10 +2487,10 @@ ircd::net::dns::operator()(const hostport &hostport,
continue; continue;
const auto &record(rr.as<const rfc1035::record::A>()); const auto &record(rr.as<const rfc1035::record::A>());
return callback(std::move(eptr), record); return callback(std::move(eptr), hp, record);
} }
return callback(std::move(eptr), {}); return callback(std::move(eptr), hp, {});
}); });
} }
@ -2602,7 +2599,6 @@ ircd::net::dns::cache::put_error(const rfc1035::question &question,
rfc1035::record::SRV record; rfc1035::record::SRV record;
record.ttl = ircd::time() + seconds(cache::clear_nxdomain).count(); //TODO: code record.ttl = ircd::time() + seconds(cache::clear_nxdomain).count(); //TODO: code
it = map.emplace_hint(it, host, record); it = map.emplace_hint(it, host, record);
it->second.tgt = unmake_SRV_key(it->first);
return &it->second; return &it->second;
} }
} }
@ -2795,7 +2791,7 @@ ircd::net::dns::cache::get(const hostport &hp,
assert(!eptr || count == 1); // if error, should only be one entry. assert(!eptr || count == 1); // if error, should only be one entry.
if(count) if(count)
cb(std::move(eptr), vector_view<const rfc1035::record *>(record.data(), count)); cb(std::move(eptr), hp, vector_view<const rfc1035::record *>(record.data(), count));
return count; return count;
} }
@ -2955,7 +2951,8 @@ ircd::net::dns::resolver::check_timeout(const uint16_t &id,
log.error("DNS timeout id:%u", id); log.error("DNS timeout id:%u", id);
// Callback gets a fresh stack off this timeout worker ctx's stack. // Callback gets a fresh stack off this timeout worker ctx's stack.
if(tag.cb) ircd::post([cb(std::move(tag.cb))] std::string host{tag.hp.host};
if(tag.cb) ircd::post([cb(std::move(tag.cb)), host(std::move(host)), port(tag.hp.port)]
{ {
using boost::system::system_error; using boost::system::system_error;
const error_code ec const error_code ec
@ -2963,7 +2960,8 @@ ircd::net::dns::resolver::check_timeout(const uint16_t &id,
boost::system::errc::timed_out, boost::system::system_category() boost::system::errc::timed_out, boost::system::system_category()
}; };
cb(std::make_exception_ptr(system_error{ec}), {}); const hostport hp{host, port};
cb(std::make_exception_ptr(system_error{ec}), hp, {});
}); });
return false; return false;
@ -2971,16 +2969,13 @@ ircd::net::dns::resolver::check_timeout(const uint16_t &id,
/// Internal resolver entry interface. /// Internal resolver entry interface.
void void
ircd::net::dns::resolver::operator()(const hostport &hostport, ircd::net::dns::resolver::operator()(const hostport &hp,
const opts &opts, const opts &opts,
callback callback) callback callback)
{ {
auto &tag auto &tag
{ {
set_tag(resolver::tag set_tag(hp, opts, std::move(callback))
{
hostport, opts, std::move(callback)
})
}; };
// Escape trunk // Escape trunk
@ -3004,7 +2999,7 @@ const
thread_local char srvbuf[512]; thread_local char srvbuf[512];
const string_view srvhost const string_view srvhost
{ {
make_SRV_key(srvbuf, tag.hp, tag.opts) make_SRV_key(srvbuf, host(tag.hp), tag.opts)
}; };
const rfc1035::question question{srvhost, "SRV"}; const rfc1035::question question{srvhost, "SRV"};
@ -3015,17 +3010,22 @@ const
return rfc1035::make_query(buf, tag.id, question); return rfc1035::make_query(buf, tag.id, question);
} }
template<class... A>
ircd::net::dns::resolver::tag & ircd::net::dns::resolver::tag &
ircd::net::dns::resolver::set_tag(tag &&tag) ircd::net::dns::resolver::set_tag(A&&... args)
{ {
while(tags.size() < 65535) while(tags.size() < 65535)
{ {
tag.id = ircd::rand::integer(1, 65535); auto id(ircd::rand::integer(1, 65535));
auto it{tags.lower_bound(tag.id)}; auto it{tags.lower_bound(id)};
if(it != end(tags) && it->first == tag.id) if(it != end(tags) && it->first == id)
continue; continue;
it = tags.emplace_hint(it, tag.id, std::move(tag)); it = tags.emplace_hint(it,
std::piecewise_construct,
std::forward_as_tuple(id),
std::forward_as_tuple(std::forward<A>(args)...));
it->second.id = id;
dock.notify_one(); dock.notify_one();
return it->second; return it->second;
} }
@ -3288,8 +3288,12 @@ try
} }
} }
// Cache no answers here.
if(!header.ancount && tag.opts.cache_result)
cache.put_error(qd.at(0), header.rcode);
if(tag.cb) if(tag.cb)
tag.cb({}, vector_view<const rfc1035::record *>(record, i)); tag.cb({}, tag.hp, vector_view<const rfc1035::record *>(record, i));
} }
catch(const std::exception &e) catch(const std::exception &e)
{ {
@ -3302,8 +3306,8 @@ catch(const std::exception &e)
if(tag.cb) if(tag.cb)
{ {
assert(tag.opts.nxdomain_exceptions); assert(header.rcode != 3 || tag.opts.nxdomain_exceptions);
tag.cb(std::current_exception(), {}); tag.cb(std::current_exception(), tag.hp, {});
} }
} }
@ -3334,7 +3338,7 @@ ircd::net::dns::resolver::handle_error(const header &header,
if(!tag.opts.nxdomain_exceptions && tag.cb) if(!tag.opts.nxdomain_exceptions && tag.cb)
{ {
assert(record); assert(record);
tag.cb({}, vector_view<const rfc1035::record *>(&record, 1)); tag.cb({}, tag.hp, vector_view<const rfc1035::record *>(&record, 1));
tag.cb = {}; tag.cb = {};
} }

View file

@ -851,7 +851,7 @@ ircd::server::peer::resolve(const hostport &hostport)
auto handler auto handler
{ {
std::bind(&peer::handle_resolve, this, ph::_1, ph::_2) std::bind(&peer::handle_resolve, this, ph::_1, ph::_2, ph::_3)
}; };
op_resolve = true; op_resolve = true;
@ -860,6 +860,7 @@ ircd::server::peer::resolve(const hostport &hostport)
void void
ircd::server::peer::handle_resolve(std::exception_ptr eptr, ircd::server::peer::handle_resolve(std::exception_ptr eptr,
const hostport &hp,
const ipport &ipport) const ipport &ipport)
try try
{ {
@ -877,7 +878,7 @@ try
open_opts.ipport = this->remote; open_opts.ipport = this->remote;
host(open_opts.hostport) = this->hostname; host(open_opts.hostport) = this->hostname;
port(open_opts.hostport) = port(ipport); port(open_opts.hostport) = port(ipport);
open_opts.common_name = this->hostname; open_opts.common_name = {};
if(unlikely(finished())) if(unlikely(finished()))
return handle_finished(); return handle_finished();

View file

@ -1693,7 +1693,7 @@ console_cmd__net__host(opt &out, const string_view &line)
net::ipport ipport; net::ipport ipport;
std::exception_ptr eptr; std::exception_ptr eptr;
net::dns(hostport, [&done, &dock, &eptr, &ipport] net::dns(hostport, [&done, &dock, &eptr, &ipport]
(std::exception_ptr eptr_, const net::ipport &ipport_) (std::exception_ptr eptr_, const net::hostport &, const net::ipport &ipport_)
{ {
eptr = std::move(eptr_); eptr = std::move(eptr_);
ipport = ipport_; ipport = ipport_;