Merge pull request #3016 from silkeh/improve-service-lookups

Improve handling of SRV records for federation connections
This commit is contained in:
Richard van der Hoff 2018-04-09 23:40:06 +01:00 committed by GitHub
commit 664adb4236
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 122 deletions

View file

@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import socket
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.error import ConnectError from twisted.internet.error import ConnectError
@ -33,7 +31,7 @@ SERVER_CACHE = {}
# our record of an individual server which can be tried to reach a destination. # our record of an individual server which can be tried to reach a destination.
# #
# "host" is actually a dotted-quad or ipv6 address string. Except when there's # "host" is the hostname acquired from the SRV record. Except when there's
# no SRV record, in which case it is the original hostname. # no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple( _Server = collections.namedtuple(
"_Server", "priority weight host port expires" "_Server", "priority weight host port expires"
@ -297,19 +295,12 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
payload = answer.payload payload = answer.payload
hosts = yield _get_hosts_for_srv_record(
dns_client, str(payload.target)
)
for (ip, ttl) in hosts:
host_ttl = min(answer.ttl, ttl)
servers.append(_Server( servers.append(_Server(
host=ip, host=str(payload.target),
port=int(payload.port), port=int(payload.port),
priority=int(payload.priority), priority=int(payload.priority),
weight=int(payload.weight), weight=int(payload.weight),
expires=int(clock.time()) + host_ttl, expires=int(clock.time()) + answer.ttl,
)) ))
servers.sort() servers.sort()
@ -328,81 +319,3 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
raise e raise e
defer.returnValue(servers) defer.returnValue(servers)
@defer.inlineCallbacks
def _get_hosts_for_srv_record(dns_client, host):
"""Look up each of the hosts in a SRV record
Args:
dns_client (twisted.names.dns.IResolver):
host (basestring): host to look up
Returns:
Deferred[list[(str, int)]]: a list of (host, ttl) pairs
"""
ip4_servers = []
ip6_servers = []
def cb(res):
# lookupAddress and lookupIP6Address return a three-tuple
# giving the answer, authority, and additional sections of the
# response.
#
# we only care about the answers.
return res[0]
def eb(res, record_type):
if res.check(DNSNameError):
return []
logger.warn("Error looking up %s for %s: %s", record_type, host, res)
return res
# no logcontexts here, so we can safely fire these off and gatherResults
d1 = dns_client.lookupAddress(host).addCallbacks(
cb, eb, errbackArgs=("A", ))
d2 = dns_client.lookupIPV6Address(host).addCallbacks(
cb, eb, errbackArgs=("AAAA", ))
results = yield defer.DeferredList(
[d1, d2], consumeErrors=True)
# if all of the lookups failed, raise an exception rather than blowing out
# the cache with an empty result.
if results and all(s == defer.FAILURE for (s, _) in results):
defer.returnValue(results[0][1])
for (success, result) in results:
if success == defer.FAILURE:
continue
for answer in result:
if not answer.payload:
continue
try:
if answer.type == dns.A:
ip = answer.payload.dottedQuad()
ip4_servers.append((ip, answer.ttl))
elif answer.type == dns.AAAA:
ip = socket.inet_ntop(
socket.AF_INET6, answer.payload.address,
)
ip6_servers.append((ip, answer.ttl))
else:
# the most likely candidate here is a CNAME record.
# rfc2782 says srvs may not point to aliases.
logger.warn(
"Ignoring unexpected DNS record type %s for %s",
answer.type, host,
)
continue
except Exception as e:
logger.warn("Ignoring invalid DNS response for %s: %s",
host, e)
continue
# keep the ipv4 results before the ipv6 results, mostly to match historical
# behaviour.
defer.returnValue(ip4_servers + ip6_servers)

View file

@ -33,8 +33,6 @@ class DnsTestCase(unittest.TestCase):
service_name = "test_service.example.com" service_name = "test_service.example.com"
host_name = "example.com" host_name = "example.com"
ip_address = "127.0.0.1"
ip6_address = "::1"
answer_srv = dns.RRHeader( answer_srv = dns.RRHeader(
type=dns.SRV, type=dns.SRV,
@ -43,29 +41,9 @@ class DnsTestCase(unittest.TestCase):
) )
) )
answer_a = dns.RRHeader(
type=dns.A,
payload=dns.Record_A(
address=ip_address,
)
)
answer_aaaa = dns.RRHeader(
type=dns.AAAA,
payload=dns.Record_AAAA(
address=ip6_address,
)
)
dns_client_mock.lookupService.return_value = defer.succeed( dns_client_mock.lookupService.return_value = defer.succeed(
([answer_srv], None, None), ([answer_srv], None, None),
) )
dns_client_mock.lookupAddress.return_value = defer.succeed(
([answer_a], None, None),
)
dns_client_mock.lookupIPV6Address.return_value = defer.succeed(
([answer_aaaa], None, None),
)
cache = {} cache = {}
@ -74,13 +52,10 @@ class DnsTestCase(unittest.TestCase):
) )
dns_client_mock.lookupService.assert_called_once_with(service_name) dns_client_mock.lookupService.assert_called_once_with(service_name)
dns_client_mock.lookupAddress.assert_called_once_with(host_name)
dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name)
self.assertEquals(len(servers), 2) self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name]) self.assertEquals(servers, cache[service_name])
self.assertEquals(servers[0].host, ip_address) self.assertEquals(servers[0].host, host_name)
self.assertEquals(servers[1].host, ip6_address)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_from_cache_expired_and_dns_fail(self): def test_from_cache_expired_and_dns_fail(self):