From f699b8f997ed743af0cfa7046428915a7f42610b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 31 Mar 2016 10:04:28 +0100 Subject: [PATCH 1/3] Read from DNS cache if within TTL --- synapse/http/endpoint.py | 37 ++++++++++++++++++++++--------------- tests/test_dns.py | 5 ++++- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 4775f6707..e80d00e2a 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError import collections import logging import random +import time logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ SERVER_CACHE = {} _Server = collections.namedtuple( - "_Server", "priority weight host port" + "_Server", "priority weight host port expires" ) @@ -92,7 +93,8 @@ class SRVClientEndpoint(object): host=domain, port=default_port, priority=0, - weight=0 + weight=0, + expires=0, ) else: self.default_server = None @@ -154,6 +156,12 @@ class SRVClientEndpoint(object): @defer.inlineCallbacks def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): + cache_entry = cache.get(service_name, None) + if cache_entry: + if all(s.expires > int(time.time()) for s in cache_entry): + servers = list(cache_entry) + defer.returnValue(servers) + servers = [] try: @@ -173,27 +181,26 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): continue payload = answer.payload - host = str(payload.target) + srv_ttl = answer.ttl try: answers, _, _ = yield dns_client.lookupAddress(host) except DNSNameError: continue - ips = [ - answer.payload.dottedQuad() - for answer in answers - if answer.type == dns.A and answer.payload - ] + for answer in answers: + if answer.type == dns.A and answer.payload: + ip = answer.payload.dottedQuad() + host_ttl = min(srv_ttl, answer.ttl) - for ip in ips: - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight) - )) + servers.append(_Server( + host=ip, + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(time.time()) + host_ttl, + )) servers.sort() cache[service_name] = list(servers) diff --git a/tests/test_dns.py b/tests/test_dns.py index 637b1606f..e006ed1a5 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -69,8 +69,11 @@ class DnsTestCase(unittest.TestCase): service_name = "test_service.examle.com" + entry = Mock(spec_set=["expires"]) + entry.expires = 999999999 + cache = { - service_name: [object()] + service_name: [entry] } servers = yield resolve_service( From f9d3665c8841335cd70325dd758b4193c462ca60 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 31 Mar 2016 10:23:48 +0100 Subject: [PATCH 2/3] Allow clock to be passed in to func --- synapse/http/endpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index e80d00e2a..bc28a2959 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -155,10 +155,10 @@ class SRVClientEndpoint(object): @defer.inlineCallbacks -def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): +def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time): cache_entry = cache.get(service_name, None) if cache_entry: - if all(s.expires > int(time.time()) for s in cache_entry): + if all(s.expires > int(clock.time()) for s in cache_entry): servers = list(cache_entry) defer.returnValue(servers) @@ -199,7 +199,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): port=int(payload.port), priority=int(payload.priority), weight=int(payload.weight), - expires=int(time.time()) + host_ttl, + expires=int(clock.time()) + host_ttl, )) servers.sort() From 11860637e116717efa14149f17d8b941d1e5db5e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 6 Apr 2016 10:12:30 +0100 Subject: [PATCH 3/3] Tests --- tests/test_dns.py | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/tests/test_dns.py b/tests/test_dns.py index e006ed1a5..c394c57ee 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -21,6 +21,8 @@ from mock import Mock from synapse.http.endpoint import resolve_service +from tests.utils import MockClock + class DnsTestCase(unittest.TestCase): @@ -63,12 +65,37 @@ class DnsTestCase(unittest.TestCase): self.assertEquals(servers[0].host, ip_address) @defer.inlineCallbacks - def test_from_cache(self): + def test_from_cache_expired_and_dns_fail(self): dns_client_mock = Mock() dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) service_name = "test_service.examle.com" + entry = Mock(spec_set=["expires"]) + entry.expires = 0 + + cache = { + service_name: [entry] + } + + servers = yield resolve_service( + service_name, dns_client=dns_client_mock, cache=cache + ) + + dns_client_mock.lookupService.assert_called_once_with(service_name) + + self.assertEquals(len(servers), 1) + self.assertEquals(servers, cache[service_name]) + + @defer.inlineCallbacks + def test_from_cache(self): + clock = MockClock() + + dns_client_mock = Mock(spec_set=['lookupService']) + dns_client_mock.lookupService = Mock(spec_set=[]) + + service_name = "test_service.examle.com" + entry = Mock(spec_set=["expires"]) entry.expires = 999999999 @@ -77,10 +104,10 @@ class DnsTestCase(unittest.TestCase): } servers = yield resolve_service( - service_name, dns_client=dns_client_mock, cache=cache + service_name, dns_client=dns_client_mock, cache=cache, clock=clock, ) - dns_client_mock.lookupService.assert_called_once_with(service_name) + self.assertFalse(dns_client_mock.lookupService.called) self.assertEquals(len(servers), 1) self.assertEquals(servers, cache[service_name])