mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 19:53:51 +01:00
Additional type hints for the proxy agent and SRV resolver modules. (#10608)
This commit is contained in:
parent
78a70a2e0b
commit
0c3565da4c
5 changed files with 41 additions and 25 deletions
1
changelog.d/10608.misc
Normal file
1
changelog.d/10608.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Improve type hints for the proxy agent and SRV resolver modules. Contributed by @dklimpel.
|
3
mypy.ini
3
mypy.ini
|
@ -28,10 +28,13 @@ files =
|
||||||
synapse/federation,
|
synapse/federation,
|
||||||
synapse/groups,
|
synapse/groups,
|
||||||
synapse/handlers,
|
synapse/handlers,
|
||||||
|
synapse/http/additional_resource.py,
|
||||||
synapse/http/client.py,
|
synapse/http/client.py,
|
||||||
synapse/http/federation/matrix_federation_agent.py,
|
synapse/http/federation/matrix_federation_agent.py,
|
||||||
|
synapse/http/federation/srv_resolver.py,
|
||||||
synapse/http/federation/well_known_resolver.py,
|
synapse/http/federation/well_known_resolver.py,
|
||||||
synapse/http/matrixfederationclient.py,
|
synapse/http/matrixfederationclient.py,
|
||||||
|
synapse/http/proxyagent.py,
|
||||||
synapse/http/servlet.py,
|
synapse/http/servlet.py,
|
||||||
synapse/http/server.py,
|
synapse/http/server.py,
|
||||||
synapse/http/site.py,
|
synapse/http/site.py,
|
||||||
|
|
|
@ -12,8 +12,15 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse.http.server import DirectServeJsonResource
|
from synapse.http.server import DirectServeJsonResource
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
|
||||||
class AdditionalResource(DirectServeJsonResource):
|
class AdditionalResource(DirectServeJsonResource):
|
||||||
"""Resource wrapper for additional_resources
|
"""Resource wrapper for additional_resources
|
||||||
|
@ -25,7 +32,7 @@ class AdditionalResource(DirectServeJsonResource):
|
||||||
and exception handling.
|
and exception handling.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs, handler):
|
def __init__(self, hs: "HomeServer", handler):
|
||||||
"""Initialise AdditionalResource
|
"""Initialise AdditionalResource
|
||||||
|
|
||||||
The ``handler`` should return a deferred which completes when it has
|
The ``handler`` should return a deferred which completes when it has
|
||||||
|
@ -33,14 +40,14 @@ class AdditionalResource(DirectServeJsonResource):
|
||||||
``request.write()``, and call ``request.finish()``.
|
``request.write()``, and call ``request.finish()``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hs (synapse.server.HomeServer): homeserver
|
hs: homeserver
|
||||||
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
|
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
|
||||||
function to be called to handle the request.
|
function to be called to handle the request.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._handler = handler
|
self._handler = handler
|
||||||
|
|
||||||
def _async_render(self, request):
|
def _async_render(self, request: Request):
|
||||||
# Cheekily pass the result straight through, so we don't need to worry
|
# Cheekily pass the result straight through, so we don't need to worry
|
||||||
# if its an awaitable or not.
|
# if its an awaitable or not.
|
||||||
return self._handler(request)
|
return self._handler(request)
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import List
|
from typing import Callable, Dict, List
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -28,35 +28,35 @@ from synapse.logging.context import make_deferred_yieldable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
SERVER_CACHE = {}
|
SERVER_CACHE: Dict[bytes, List["Server"]] = {}
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True)
|
@attr.s(auto_attribs=True, slots=True, frozen=True)
|
||||||
class Server:
|
class Server:
|
||||||
"""
|
"""
|
||||||
Our record of an individual server which can be tried to reach a destination.
|
Our record of an individual server which can be tried to reach a destination.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
host (bytes): target hostname
|
host: target hostname
|
||||||
port (int):
|
port:
|
||||||
priority (int):
|
priority:
|
||||||
weight (int):
|
weight:
|
||||||
expires (int): when the cache should expire this record - in *seconds* since
|
expires: when the cache should expire this record - in *seconds* since
|
||||||
the epoch
|
the epoch
|
||||||
"""
|
"""
|
||||||
|
|
||||||
host = attr.ib()
|
host: bytes
|
||||||
port = attr.ib()
|
port: int
|
||||||
priority = attr.ib(default=0)
|
priority: int = 0
|
||||||
weight = attr.ib(default=0)
|
weight: int = 0
|
||||||
expires = attr.ib(default=0)
|
expires: int = 0
|
||||||
|
|
||||||
|
|
||||||
def _sort_server_list(server_list):
|
def _sort_server_list(server_list: List[Server]) -> List[Server]:
|
||||||
"""Given a list of SRV records sort them into priority order and shuffle
|
"""Given a list of SRV records sort them into priority order and shuffle
|
||||||
each priority with the given weight.
|
each priority with the given weight.
|
||||||
"""
|
"""
|
||||||
priority_map = {}
|
priority_map: Dict[int, List[Server]] = {}
|
||||||
|
|
||||||
for server in server_list:
|
for server in server_list:
|
||||||
priority_map.setdefault(server.priority, []).append(server)
|
priority_map.setdefault(server.priority, []).append(server)
|
||||||
|
@ -103,11 +103,16 @@ class SrvResolver:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
|
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
|
||||||
cache (dict): cache object
|
cache: cache object
|
||||||
get_time (callable): clock implementation. Should return seconds since the epoch
|
get_time: clock implementation. Should return seconds since the epoch
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
|
def __init__(
|
||||||
|
self,
|
||||||
|
dns_client=client,
|
||||||
|
cache: Dict[bytes, List[Server]] = SERVER_CACHE,
|
||||||
|
get_time: Callable[[], float] = time.time,
|
||||||
|
):
|
||||||
self._dns_client = dns_client
|
self._dns_client = dns_client
|
||||||
self._cache = cache
|
self._cache = cache
|
||||||
self._get_time = get_time
|
self._get_time = get_time
|
||||||
|
@ -116,7 +121,7 @@ class SrvResolver:
|
||||||
"""Look up a SRV record
|
"""Look up a SRV record
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service_name (bytes): record to look up
|
service_name: record to look up
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a list of the SRV records, or an empty list if none found
|
a list of the SRV records, or an empty list if none found
|
||||||
|
@ -158,7 +163,7 @@ class SrvResolver:
|
||||||
and answers[0].payload
|
and answers[0].payload
|
||||||
and answers[0].payload.target == dns.Name(b".")
|
and answers[0].payload.target == dns.Name(b".")
|
||||||
):
|
):
|
||||||
raise ConnectError("Service %s unavailable" % service_name)
|
raise ConnectError(f"Service {service_name!r} unavailable")
|
||||||
|
|
||||||
servers = []
|
servers = []
|
||||||
|
|
||||||
|
|
|
@ -173,7 +173,7 @@ class ProxyAgent(_AgentBase):
|
||||||
raise ValueError(f"Invalid URI {uri!r}")
|
raise ValueError(f"Invalid URI {uri!r}")
|
||||||
|
|
||||||
parsed_uri = URI.fromBytes(uri)
|
parsed_uri = URI.fromBytes(uri)
|
||||||
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
|
pool_key = f"{parsed_uri.scheme!r}{parsed_uri.host!r}{parsed_uri.port}"
|
||||||
request_path = parsed_uri.originForm
|
request_path = parsed_uri.originForm
|
||||||
|
|
||||||
should_skip_proxy = False
|
should_skip_proxy = False
|
||||||
|
@ -199,7 +199,7 @@ class ProxyAgent(_AgentBase):
|
||||||
)
|
)
|
||||||
# Cache *all* connections under the same key, since we are only
|
# Cache *all* connections under the same key, since we are only
|
||||||
# connecting to a single destination, the proxy:
|
# connecting to a single destination, the proxy:
|
||||||
pool_key = ("http-proxy", self.http_proxy_endpoint)
|
pool_key = "http-proxy"
|
||||||
endpoint = self.http_proxy_endpoint
|
endpoint = self.http_proxy_endpoint
|
||||||
request_path = uri
|
request_path = uri
|
||||||
elif (
|
elif (
|
||||||
|
|
Loading…
Reference in a new issue