forked from MirrorHub/synapse
Read from DNS cache if within TTL
This commit is contained in:
parent
a68c1b15aa
commit
f699b8f997
2 changed files with 26 additions and 16 deletions
|
@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError
|
|||
import collections
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -31,7 +32,7 @@ SERVER_CACHE = {}
|
|||
|
||||
|
||||
_Server = collections.namedtuple(
|
||||
"_Server", "priority weight host port"
|
||||
"_Server", "priority weight host port expires"
|
||||
)
|
||||
|
||||
|
||||
|
@ -92,7 +93,8 @@ class SRVClientEndpoint(object):
|
|||
host=domain,
|
||||
port=default_port,
|
||||
priority=0,
|
||||
weight=0
|
||||
weight=0,
|
||||
expires=0,
|
||||
)
|
||||
else:
|
||||
self.default_server = None
|
||||
|
@ -154,6 +156,12 @@ class SRVClientEndpoint(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
|
||||
cache_entry = cache.get(service_name, None)
|
||||
if cache_entry:
|
||||
if all(s.expires > int(time.time()) for s in cache_entry):
|
||||
servers = list(cache_entry)
|
||||
defer.returnValue(servers)
|
||||
|
||||
servers = []
|
||||
|
||||
try:
|
||||
|
@ -173,27 +181,26 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
|
|||
continue
|
||||
|
||||
payload = answer.payload
|
||||
|
||||
host = str(payload.target)
|
||||
srv_ttl = answer.ttl
|
||||
|
||||
try:
|
||||
answers, _, _ = yield dns_client.lookupAddress(host)
|
||||
except DNSNameError:
|
||||
continue
|
||||
|
||||
ips = [
|
||||
answer.payload.dottedQuad()
|
||||
for answer in answers
|
||||
if answer.type == dns.A and answer.payload
|
||||
]
|
||||
for answer in answers:
|
||||
if answer.type == dns.A and answer.payload:
|
||||
ip = answer.payload.dottedQuad()
|
||||
host_ttl = min(srv_ttl, answer.ttl)
|
||||
|
||||
for ip in ips:
|
||||
servers.append(_Server(
|
||||
host=ip,
|
||||
port=int(payload.port),
|
||||
priority=int(payload.priority),
|
||||
weight=int(payload.weight)
|
||||
))
|
||||
servers.append(_Server(
|
||||
host=ip,
|
||||
port=int(payload.port),
|
||||
priority=int(payload.priority),
|
||||
weight=int(payload.weight),
|
||||
expires=int(time.time()) + host_ttl,
|
||||
))
|
||||
|
||||
servers.sort()
|
||||
cache[service_name] = list(servers)
|
||||
|
|
|
@ -69,8 +69,11 @@ class DnsTestCase(unittest.TestCase):
|
|||
|
||||
service_name = "test_service.examle.com"
|
||||
|
||||
entry = Mock(spec_set=["expires"])
|
||||
entry.expires = 999999999
|
||||
|
||||
cache = {
|
||||
service_name: [object()]
|
||||
service_name: [entry]
|
||||
}
|
||||
|
||||
servers = yield resolve_service(
|
||||
|
|
Loading…
Reference in a new issue