mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-17 07:21:37 +01:00
Require that service_name be a byte string
it is only ever a bytes now, so let's enforce that.
This commit is contained in:
parent
c66f4bf7f1
commit
53a327b4d5
2 changed files with 8 additions and 8 deletions
|
@ -92,7 +92,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
|
|||
but the cache never gets populated), so we add our own caching layer here.
|
||||
|
||||
Args:
|
||||
service_name (unicode|bytes): record to look up
|
||||
service_name (bytes): record to look up
|
||||
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
|
||||
cache (dict): cache object
|
||||
clock (object): clock implementation. must provide a time() method.
|
||||
|
@ -100,9 +100,9 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
|
|||
Returns:
|
||||
Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
|
||||
"""
|
||||
# TODO: the dns client handles both unicode names (encoding via idna) and pre-encoded
|
||||
# byteses; however they will obviously end up as separate entries in the cache. We
|
||||
# should pick one form and stick with it.
|
||||
if not isinstance(service_name, bytes):
|
||||
raise TypeError("%r is not a byte string" % (service_name,))
|
||||
|
||||
cache_entry = cache.get(service_name, None)
|
||||
if cache_entry:
|
||||
if all(s.expires > int(clock.time()) for s in cache_entry):
|
||||
|
|
|
@ -83,7 +83,7 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||
dns_client_mock = Mock()
|
||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||
|
||||
service_name = "test_service.example.com"
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
entry = Mock(spec_set=["expires"])
|
||||
entry.expires = 0
|
||||
|
@ -106,7 +106,7 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||
dns_client_mock = Mock(spec_set=['lookupService'])
|
||||
dns_client_mock.lookupService = Mock(spec_set=[])
|
||||
|
||||
service_name = "test_service.example.com"
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
entry = Mock(spec_set=["expires"])
|
||||
entry.expires = 999999999
|
||||
|
@ -128,7 +128,7 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||
|
||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
|
||||
|
||||
service_name = "test_service.example.com"
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
cache = {}
|
||||
|
||||
|
@ -141,7 +141,7 @@ class SrvResolverTestCase(unittest.TestCase):
|
|||
|
||||
dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
|
||||
|
||||
service_name = "test_service.example.com"
|
||||
service_name = b"test_service.example.com"
|
||||
|
||||
cache = {}
|
||||
|
||||
|
|
Loading…
Reference in a new issue