mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 16:13:50 +01:00
Merge pull request #677 from matrix-org/erikj/dns_cache
Read from DNS cache if within TTL
This commit is contained in:
commit
79fc4ff6f9
2 changed files with 55 additions and 18 deletions
|
@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError
|
|||
import collections
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -31,7 +32,7 @@ SERVER_CACHE = {}
|
|||
|
||||
|
||||
_Server = collections.namedtuple(
|
||||
"_Server", "priority weight host port"
|
||||
"_Server", "priority weight host port expires"
|
||||
)
|
||||
|
||||
|
||||
|
@ -92,7 +93,8 @@ class SRVClientEndpoint(object):
|
|||
host=domain,
|
||||
port=default_port,
|
||||
priority=0,
|
||||
weight=0
|
||||
weight=0,
|
||||
expires=0,
|
||||
)
|
||||
else:
|
||||
self.default_server = None
|
||||
|
@ -153,7 +155,13 @@ class SRVClientEndpoint(object):
|
|||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
|
||||
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
|
||||
cache_entry = cache.get(service_name, None)
|
||||
if cache_entry:
|
||||
if all(s.expires > int(clock.time()) for s in cache_entry):
|
||||
servers = list(cache_entry)
|
||||
defer.returnValue(servers)
|
||||
|
||||
servers = []
|
||||
|
||||
try:
|
||||
|
@ -173,26 +181,25 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
|
|||
continue
|
||||
|
||||
payload = answer.payload
|
||||
|
||||
host = str(payload.target)
|
||||
srv_ttl = answer.ttl
|
||||
|
||||
try:
|
||||
answers, _, _ = yield dns_client.lookupAddress(host)
|
||||
except DNSNameError:
|
||||
continue
|
||||
|
||||
ips = [
|
||||
answer.payload.dottedQuad()
|
||||
for answer in answers
|
||||
if answer.type == dns.A and answer.payload
|
||||
]
|
||||
for answer in answers:
|
||||
if answer.type == dns.A and answer.payload:
|
||||
ip = answer.payload.dottedQuad()
|
||||
host_ttl = min(srv_ttl, answer.ttl)
|
||||
|
||||
for ip in ips:
|
||||
servers.append(_Server(
|
||||
host=ip,
|
||||
port=int(payload.port),
|
||||
priority=int(payload.priority),
|
||||
weight=int(payload.weight)
|
||||
weight=int(payload.weight),
|
||||
expires=int(clock.time()) + host_ttl,
|
||||
))
|
||||
|
||||
servers.sort()
|
||||
|
|
|
@ -21,6 +21,8 @@ from mock import Mock
|
|||
|
||||
from synapse.http.endpoint import resolve_service
|
||||
|
||||
from tests.utils import MockClock
|
||||
|
||||
|
||||
class DnsTestCase(unittest.TestCase):
|
||||
|
||||
|
@ -63,14 +65,17 @@ class DnsTestCase(unittest.TestCase):
|
|||
self.assertEquals(servers[0].host, ip_address)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_from_cache(self):
|
||||
def test_from_cache_expired_and_dns_fail(self):
|
||||
dns_client_mock = Mock()
|
||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||
|
||||
service_name = "test_service.examle.com"
|
||||
|
||||
entry = Mock(spec_set=["expires"])
|
||||
entry.expires = 0
|
||||
|
||||
cache = {
|
||||
service_name: [object()]
|
||||
service_name: [entry]
|
||||
}
|
||||
|
||||
servers = yield resolve_service(
|
||||
|
@ -82,6 +87,31 @@ class DnsTestCase(unittest.TestCase):
|
|||
self.assertEquals(len(servers), 1)
|
||||
self.assertEquals(servers, cache[service_name])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_from_cache(self):
|
||||
clock = MockClock()
|
||||
|
||||
dns_client_mock = Mock(spec_set=['lookupService'])
|
||||
dns_client_mock.lookupService = Mock(spec_set=[])
|
||||
|
||||
service_name = "test_service.examle.com"
|
||||
|
||||
entry = Mock(spec_set=["expires"])
|
||||
entry.expires = 999999999
|
||||
|
||||
cache = {
|
||||
service_name: [entry]
|
||||
}
|
||||
|
||||
servers = yield resolve_service(
|
||||
service_name, dns_client=dns_client_mock, cache=cache, clock=clock,
|
||||
)
|
||||
|
||||
self.assertFalse(dns_client_mock.lookupService.called)
|
||||
|
||||
self.assertEquals(len(servers), 1)
|
||||
self.assertEquals(servers, cache[service_name])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_empty_cache(self):
|
||||
dns_client_mock = Mock()
|
||||
|
|
Loading…
Reference in a new issue