0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-19 16:32:24 +01:00

Refactor and bugfix for resove_service (#4427)

This commit is contained in:
Richard van der Hoff 2019-01-22 10:59:27 +00:00 committed by GitHub
parent 23b0813599
commit 33a55289cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 250 additions and 86 deletions

1
changelog.d/4427.misc Normal file
View file

@ -0,0 +1 @@
Refactor and cleanup for SRV record lookup

View file

@ -12,30 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
import logging import logging
import random import random
import re import re
import time
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.error import ConnectError from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError from synapse.http.federation.srv_resolver import Server, resolve_service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SERVER_CACHE = {}
# our record of an individual server which can be tried to reach a destination.
#
# "host" is the hostname acquired from the SRV record. Except when there's
# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple(
"_Server", "priority weight host port expires"
)
def parse_server_name(server_name): def parse_server_name(server_name):
"""Split a server name into host/port parts. """Split a server name into host/port parts.
@ -165,12 +153,9 @@ class SRVClientEndpoint(object):
self.service_name = "_%s._%s.%s" % (service, protocol, domain) self.service_name = "_%s._%s.%s" % (service, protocol, domain)
if default_port is not None: if default_port is not None:
self.default_server = _Server( self.default_server = Server(
host=domain, host=domain,
port=default_port, port=default_port,
priority=0,
weight=0,
expires=0,
) )
else: else:
self.default_server = None self.default_server = None
@ -240,57 +225,3 @@ class SRVClientEndpoint(object):
) )
connection = yield endpoint.connect(protocolFactory) connection = yield endpoint.connect(protocolFactory)
defer.returnValue(connection) defer.returnValue(connection)
@defer.inlineCallbacks
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
cache_entry = cache.get(service_name, None)
if cache_entry:
if all(s.expires > int(clock.time()) for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
servers = []
try:
try:
answers, _, _ = yield dns_client.lookupService(service_name)
except DNSNameError:
defer.returnValue([])
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
payload = answer.payload
servers.append(_Server(
host=str(payload.target),
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight),
expires=int(clock.time()) + answer.ttl,
))
servers.sort()
cache[service_name] = list(servers)
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else rereaise
cache_entry = cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
)
servers = list(cache_entry)
else:
raise e
defer.returnValue(servers)

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,124 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
import attr
from twisted.internet import defer
from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError
from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
SERVER_CACHE = {}
@attr.s
class Server(object):
"""
Our record of an individual server which can be tried to reach a destination.
Attributes:
host (bytes): target hostname
port (int):
priority (int):
weight (int):
expires (int): when the cache should expire this record - in *seconds* since
the epoch
"""
host = attr.ib()
port = attr.ib()
priority = attr.ib(default=0)
weight = attr.ib(default=0)
expires = attr.ib(default=0)
@defer.inlineCallbacks
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
"""Look up a SRV record, with caching
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
but the cache never gets populated), so we add our own caching layer here.
Args:
service_name (unicode|bytes): record to look up
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
cache (dict): cache object
clock (object): clock implementation. must provide a time() method.
Returns:
Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
"""
# TODO: the dns client handles both unicode names (encoding via idna) and pre-encoded
# byteses; however they will obviously end up as separate entries in the cache. We
# should pick one form and stick with it.
cache_entry = cache.get(service_name, None)
if cache_entry:
if all(s.expires > int(clock.time()) for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
try:
answers, _, _ = yield make_deferred_yieldable(
dns_client.lookupService(service_name),
)
except DNSNameError:
# TODO: cache this. We can get the SOA out of the exception, and use
# the negative-TTL value.
defer.returnValue([])
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else rereaise
cache_entry = cache.get(service_name, None)
if cache_entry:
logger.warn(
"Failed to resolve %r, falling back to cache. %r",
service_name, e
)
defer.returnValue(list(cache_entry))
else:
raise e
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
servers = []
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
payload = answer.payload
servers.append(Server(
host=payload.target.name,
port=payload.port,
priority=payload.priority,
weight=payload.weight,
expires=int(clock.time()) + answer.ttl,
))
servers.sort() # FIXME: get rid of this (it's broken by the attrs change)
cache[service_name] = list(servers)
defer.returnValue(servers)

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,40 +17,63 @@
from mock import Mock from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.internet.error import ConnectError
from twisted.names import dns, error from twisted.names import dns, error
from synapse.http.endpoint import resolve_service from synapse.http.federation.srv_resolver import resolve_service
from synapse.util.logcontext import LoggingContext
from tests import unittest
from tests.utils import MockClock from tests.utils import MockClock
from . import unittest
class SrvResolverTestCase(unittest.TestCase):
@unittest.DEBUG
class DnsTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_resolve(self): def test_resolve(self):
dns_client_mock = Mock() dns_client_mock = Mock()
service_name = "test_service.example.com" service_name = b"test_service.example.com"
host_name = "example.com" host_name = b"example.com"
answer_srv = dns.RRHeader( answer_srv = dns.RRHeader(
type=dns.SRV, payload=dns.Record_SRV(target=host_name) type=dns.SRV, payload=dns.Record_SRV(target=host_name)
) )
dns_client_mock.lookupService.return_value = defer.succeed( result_deferred = Deferred()
([answer_srv], None, None) dns_client_mock.lookupService.return_value = result_deferred
)
cache = {} cache = {}
servers = yield resolve_service( @defer.inlineCallbacks
def do_lookup():
with LoggingContext("one") as ctx:
resolve_d = resolve_service(
service_name, dns_client=dns_client_mock, cache=cache service_name, dns_client=dns_client_mock, cache=cache
) )
self.assertNoResult(resolve_d)
# should have reset to the sentinel context
self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
result = yield resolve_d
# should have restored our context
self.assertIs(LoggingContext.current_context(), ctx)
defer.returnValue(result)
test_d = do_lookup()
self.assertNoResult(test_d)
dns_client_mock.lookupService.assert_called_once_with(service_name) dns_client_mock.lookupService.assert_called_once_with(service_name)
result_deferred.callback(
([answer_srv], None, None)
)
servers = self.successResultOf(test_d)
self.assertEquals(len(servers), 1) self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name]) self.assertEquals(servers, cache[service_name])
self.assertEquals(servers[0].host, host_name) self.assertEquals(servers[0].host, host_name)
@ -127,3 +151,59 @@ class DnsTestCase(unittest.TestCase):
self.assertEquals(len(servers), 0) self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0) self.assertEquals(len(cache), 0)
def test_disabled_service(self):
"""
test the behaviour when there is a single record which is ".".
"""
service_name = b"test_service.example.com"
lookup_deferred = Deferred()
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred
cache = {}
resolve_d = resolve_service(
service_name, dns_client=dns_client_mock, cache=cache
)
self.assertNoResult(resolve_d)
# returning a single "." should make the lookup fail with a ConenctError
lookup_deferred.callback((
[dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))],
None,
None,
))
self.failureResultOf(resolve_d, ConnectError)
def test_non_srv_answer(self):
"""
test the behaviour when the dns server gives us a spurious non-SRV response
"""
service_name = b"test_service.example.com"
lookup_deferred = Deferred()
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = lookup_deferred
cache = {}
resolve_d = resolve_service(
service_name, dns_client=dns_client_mock, cache=cache
)
self.assertNoResult(resolve_d)
lookup_deferred.callback((
[
dns.RRHeader(type=dns.A, payload=dns.Record_A()),
dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")),
],
None,
None,
))
servers = self.successResultOf(resolve_d)
self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name])
self.assertEquals(servers[0].host, b"host")