Require that service_name be a byte string

it is only ever a bytes now, so let's enforce that.
This commit is contained in:
Richard van der Hoff 2019-01-22 17:35:09 +00:00
parent c66f4bf7f1
commit 53a327b4d5
2 changed files with 8 additions and 8 deletions

View file

@ -92,7 +92,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
but the cache never gets populated), so we add our own caching layer here. but the cache never gets populated), so we add our own caching layer here.
Args: Args:
service_name (unicode|bytes): record to look up service_name (bytes): record to look up
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
cache (dict): cache object cache (dict): cache object
clock (object): clock implementation. must provide a time() method. clock (object): clock implementation. must provide a time() method.
@ -100,9 +100,9 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
Returns: Returns:
Deferred[list[Server]]: a list of the SRV records, or an empty list if none found Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
""" """
# TODO: the dns client handles both unicode names (encoding via idna) and pre-encoded if not isinstance(service_name, bytes):
# byteses; however they will obviously end up as separate entries in the cache. We raise TypeError("%r is not a byte string" % (service_name,))
# should pick one form and stick with it.
cache_entry = cache.get(service_name, None) cache_entry = cache.get(service_name, None)
if cache_entry: if cache_entry:
if all(s.expires > int(clock.time()) for s in cache_entry): if all(s.expires > int(clock.time()) for s in cache_entry):

View file

@ -83,7 +83,7 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock() dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
service_name = "test_service.example.com" service_name = b"test_service.example.com"
entry = Mock(spec_set=["expires"]) entry = Mock(spec_set=["expires"])
entry.expires = 0 entry.expires = 0
@ -106,7 +106,7 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock = Mock(spec_set=['lookupService']) dns_client_mock = Mock(spec_set=['lookupService'])
dns_client_mock.lookupService = Mock(spec_set=[]) dns_client_mock.lookupService = Mock(spec_set=[])
service_name = "test_service.example.com" service_name = b"test_service.example.com"
entry = Mock(spec_set=["expires"]) entry = Mock(spec_set=["expires"])
entry.expires = 999999999 entry.expires = 999999999
@ -128,7 +128,7 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
service_name = "test_service.example.com" service_name = b"test_service.example.com"
cache = {} cache = {}
@ -141,7 +141,7 @@ class SrvResolverTestCase(unittest.TestCase):
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError()) dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
service_name = "test_service.example.com" service_name = b"test_service.example.com"
cache = {} cache = {}