mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-16 15:01:23 +01:00
put resolve_service in an object
this makes it easier to stub things out for tests.
This commit is contained in:
parent
53a327b4d5
commit
7021784d46
3 changed files with 96 additions and 75 deletions
|
@ -22,7 +22,7 @@ from twisted.web.client import URI, Agent, HTTPConnectionPool
|
||||||
from twisted.web.iweb import IAgent
|
from twisted.web.iweb import IAgent
|
||||||
|
|
||||||
from synapse.http.endpoint import parse_server_name
|
from synapse.http.endpoint import parse_server_name
|
||||||
from synapse.http.federation.srv_resolver import pick_server_from_list, resolve_service
|
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
|
||||||
from synapse.util.logcontext import make_deferred_yieldable
|
from synapse.util.logcontext import make_deferred_yieldable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -37,13 +37,23 @@ class MatrixFederationAgent(object):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
reactor (IReactor): twisted reactor to use for underlying requests
|
reactor (IReactor): twisted reactor to use for underlying requests
|
||||||
|
|
||||||
tls_client_options_factory (ClientTLSOptionsFactory|None):
|
tls_client_options_factory (ClientTLSOptionsFactory|None):
|
||||||
factory to use for fetching client tls options, or none to disable TLS.
|
factory to use for fetching client tls options, or none to disable TLS.
|
||||||
|
|
||||||
|
srv_resolver (SrvResolver|None):
|
||||||
|
SRVResolver impl to use for looking up SRV records. None to use a default
|
||||||
|
implementation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, reactor, tls_client_options_factory):
|
def __init__(
|
||||||
|
self, reactor, tls_client_options_factory, _srv_resolver=None,
|
||||||
|
):
|
||||||
self._reactor = reactor
|
self._reactor = reactor
|
||||||
self._tls_client_options_factory = tls_client_options_factory
|
self._tls_client_options_factory = tls_client_options_factory
|
||||||
|
if _srv_resolver is None:
|
||||||
|
_srv_resolver = SrvResolver()
|
||||||
|
self._srv_resolver = _srv_resolver
|
||||||
|
|
||||||
self._pool = HTTPConnectionPool(reactor)
|
self._pool = HTTPConnectionPool(reactor)
|
||||||
self._pool.retryAutomatically = False
|
self._pool.retryAutomatically = False
|
||||||
|
@ -91,7 +101,7 @@ class MatrixFederationAgent(object):
|
||||||
if port is not None:
|
if port is not None:
|
||||||
target = (host, port)
|
target = (host, port)
|
||||||
else:
|
else:
|
||||||
server_list = yield resolve_service(server_name_bytes)
|
server_list = yield self._srv_resolver.resolve_service(server_name_bytes)
|
||||||
if not server_list:
|
if not server_list:
|
||||||
target = (host, 8448)
|
target = (host, 8448)
|
||||||
logger.debug("No SRV record for %s, using %s", host, target)
|
logger.debug("No SRV record for %s, using %s", host, target)
|
||||||
|
|
|
@ -84,73 +84,86 @@ def pick_server_from_list(server_list):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
class SrvResolver(object):
|
||||||
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
|
"""Interface to the dns client to do SRV lookups, with result caching.
|
||||||
"""Look up a SRV record, with caching
|
|
||||||
|
|
||||||
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
|
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
|
||||||
but the cache never gets populated), so we add our own caching layer here.
|
but the cache never gets populated), so we add our own caching layer here.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service_name (bytes): record to look up
|
|
||||||
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
|
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
|
||||||
cache (dict): cache object
|
cache (dict): cache object
|
||||||
clock (object): clock implementation. must provide a time() method.
|
get_time (callable): clock implementation. Should return seconds since the epoch
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
|
|
||||||
"""
|
"""
|
||||||
if not isinstance(service_name, bytes):
|
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
|
||||||
raise TypeError("%r is not a byte string" % (service_name,))
|
self._dns_client = dns_client
|
||||||
|
self._cache = cache
|
||||||
|
self._get_time = get_time
|
||||||
|
|
||||||
cache_entry = cache.get(service_name, None)
|
@defer.inlineCallbacks
|
||||||
if cache_entry:
|
def resolve_service(self, service_name):
|
||||||
if all(s.expires > int(clock.time()) for s in cache_entry):
|
"""Look up a SRV record
|
||||||
servers = list(cache_entry)
|
|
||||||
defer.returnValue(servers)
|
|
||||||
|
|
||||||
try:
|
Args:
|
||||||
answers, _, _ = yield make_deferred_yieldable(
|
service_name (bytes): record to look up
|
||||||
dns_client.lookupService(service_name),
|
|
||||||
)
|
Returns:
|
||||||
except DNSNameError:
|
Deferred[list[Server]]:
|
||||||
# TODO: cache this. We can get the SOA out of the exception, and use
|
a list of the SRV records, or an empty list if none found
|
||||||
# the negative-TTL value.
|
"""
|
||||||
defer.returnValue([])
|
now = int(self._get_time())
|
||||||
except DomainError as e:
|
|
||||||
# We failed to resolve the name (other than a NameError)
|
if not isinstance(service_name, bytes):
|
||||||
# Try something in the cache, else rereaise
|
raise TypeError("%r is not a byte string" % (service_name,))
|
||||||
cache_entry = cache.get(service_name, None)
|
|
||||||
|
cache_entry = self._cache.get(service_name, None)
|
||||||
if cache_entry:
|
if cache_entry:
|
||||||
logger.warn(
|
if all(s.expires > now for s in cache_entry):
|
||||||
"Failed to resolve %r, falling back to cache. %r",
|
servers = list(cache_entry)
|
||||||
service_name, e
|
defer.returnValue(servers)
|
||||||
|
|
||||||
|
try:
|
||||||
|
answers, _, _ = yield make_deferred_yieldable(
|
||||||
|
self._dns_client.lookupService(service_name),
|
||||||
)
|
)
|
||||||
defer.returnValue(list(cache_entry))
|
except DNSNameError:
|
||||||
else:
|
# TODO: cache this. We can get the SOA out of the exception, and use
|
||||||
raise e
|
# the negative-TTL value.
|
||||||
|
defer.returnValue([])
|
||||||
|
except DomainError as e:
|
||||||
|
# We failed to resolve the name (other than a NameError)
|
||||||
|
# Try something in the cache, else rereaise
|
||||||
|
cache_entry = self._cache.get(service_name, None)
|
||||||
|
if cache_entry:
|
||||||
|
logger.warn(
|
||||||
|
"Failed to resolve %r, falling back to cache. %r",
|
||||||
|
service_name, e
|
||||||
|
)
|
||||||
|
defer.returnValue(list(cache_entry))
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
if (len(answers) == 1
|
if (len(answers) == 1
|
||||||
and answers[0].type == dns.SRV
|
and answers[0].type == dns.SRV
|
||||||
and answers[0].payload
|
and answers[0].payload
|
||||||
and answers[0].payload.target == dns.Name(b'.')):
|
and answers[0].payload.target == dns.Name(b'.')):
|
||||||
raise ConnectError("Service %s unavailable" % service_name)
|
raise ConnectError("Service %s unavailable" % service_name)
|
||||||
|
|
||||||
servers = []
|
servers = []
|
||||||
|
|
||||||
for answer in answers:
|
for answer in answers:
|
||||||
if answer.type != dns.SRV or not answer.payload:
|
if answer.type != dns.SRV or not answer.payload:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
payload = answer.payload
|
payload = answer.payload
|
||||||
|
|
||||||
servers.append(Server(
|
servers.append(Server(
|
||||||
host=payload.target.name,
|
host=payload.target.name,
|
||||||
port=payload.port,
|
port=payload.port,
|
||||||
priority=payload.priority,
|
priority=payload.priority,
|
||||||
weight=payload.weight,
|
weight=payload.weight,
|
||||||
expires=int(clock.time()) + answer.ttl,
|
expires=now + answer.ttl,
|
||||||
))
|
))
|
||||||
|
|
||||||
cache[service_name] = list(servers)
|
self._cache[service_name] = list(servers)
|
||||||
defer.returnValue(servers)
|
defer.returnValue(servers)
|
||||||
|
|
|
@ -21,7 +21,7 @@ from twisted.internet.defer import Deferred
|
||||||
from twisted.internet.error import ConnectError
|
from twisted.internet.error import ConnectError
|
||||||
from twisted.names import dns, error
|
from twisted.names import dns, error
|
||||||
|
|
||||||
from synapse.http.federation.srv_resolver import resolve_service
|
from synapse.http.federation.srv_resolver import SrvResolver
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -43,13 +43,13 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
dns_client_mock.lookupService.return_value = result_deferred
|
dns_client_mock.lookupService.return_value = result_deferred
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_lookup():
|
def do_lookup():
|
||||||
|
|
||||||
with LoggingContext("one") as ctx:
|
with LoggingContext("one") as ctx:
|
||||||
resolve_d = resolve_service(
|
resolve_d = resolver.resolve_service(service_name)
|
||||||
service_name, dns_client=dns_client_mock, cache=cache
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertNoResult(resolve_d)
|
self.assertNoResult(resolve_d)
|
||||||
|
|
||||||
|
@ -89,10 +89,9 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
entry.expires = 0
|
entry.expires = 0
|
||||||
|
|
||||||
cache = {service_name: [entry]}
|
cache = {service_name: [entry]}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
servers = yield resolve_service(
|
servers = yield resolver.resolve_service(service_name)
|
||||||
service_name, dns_client=dns_client_mock, cache=cache
|
|
||||||
)
|
|
||||||
|
|
||||||
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
dns_client_mock.lookupService.assert_called_once_with(service_name)
|
||||||
|
|
||||||
|
@ -112,11 +111,12 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
entry.expires = 999999999
|
entry.expires = 999999999
|
||||||
|
|
||||||
cache = {service_name: [entry]}
|
cache = {service_name: [entry]}
|
||||||
|
resolver = SrvResolver(
|
||||||
servers = yield resolve_service(
|
dns_client=dns_client_mock, cache=cache, get_time=clock.time,
|
||||||
service_name, dns_client=dns_client_mock, cache=cache, clock=clock
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
servers = yield resolver.resolve_service(service_name)
|
||||||
|
|
||||||
self.assertFalse(dns_client_mock.lookupService.called)
|
self.assertFalse(dns_client_mock.lookupService.called)
|
||||||
|
|
||||||
self.assertEquals(len(servers), 1)
|
self.assertEquals(len(servers), 1)
|
||||||
|
@ -131,9 +131,10 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
service_name = b"test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
with self.assertRaises(error.DNSServerError):
|
with self.assertRaises(error.DNSServerError):
|
||||||
yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache)
|
yield resolver.resolve_service(service_name)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_name_error(self):
|
def test_name_error(self):
|
||||||
|
@ -144,10 +145,9 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
service_name = b"test_service.example.com"
|
service_name = b"test_service.example.com"
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
servers = yield resolve_service(
|
servers = yield resolver.resolve_service(service_name)
|
||||||
service_name, dns_client=dns_client_mock, cache=cache
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEquals(len(servers), 0)
|
self.assertEquals(len(servers), 0)
|
||||||
self.assertEquals(len(cache), 0)
|
self.assertEquals(len(cache), 0)
|
||||||
|
@ -162,10 +162,9 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
dns_client_mock.lookupService.return_value = lookup_deferred
|
dns_client_mock.lookupService.return_value = lookup_deferred
|
||||||
cache = {}
|
cache = {}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
resolve_d = resolve_service(
|
resolve_d = resolver.resolve_service(service_name)
|
||||||
service_name, dns_client=dns_client_mock, cache=cache
|
|
||||||
)
|
|
||||||
self.assertNoResult(resolve_d)
|
self.assertNoResult(resolve_d)
|
||||||
|
|
||||||
# returning a single "." should make the lookup fail with a ConenctError
|
# returning a single "." should make the lookup fail with a ConenctError
|
||||||
|
@ -187,10 +186,9 @@ class SrvResolverTestCase(unittest.TestCase):
|
||||||
dns_client_mock = Mock()
|
dns_client_mock = Mock()
|
||||||
dns_client_mock.lookupService.return_value = lookup_deferred
|
dns_client_mock.lookupService.return_value = lookup_deferred
|
||||||
cache = {}
|
cache = {}
|
||||||
|
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
|
||||||
|
|
||||||
resolve_d = resolve_service(
|
resolve_d = resolver.resolve_service(service_name)
|
||||||
service_name, dns_client=dns_client_mock, cache=cache
|
|
||||||
)
|
|
||||||
self.assertNoResult(resolve_d)
|
self.assertNoResult(resolve_d)
|
||||||
|
|
||||||
lookup_deferred.callback((
|
lookup_deferred.callback((
|
||||||
|
|
Loading…
Reference in a new issue