0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-25 05:58:21 +02:00

Add types to http.site (#10867)

This commit is contained in:
Erik Johnston 2021-09-21 17:41:27 +01:00 committed by GitHub
parent ebd8baf61f
commit b25a494779
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 18 deletions

1
changelog.d/10867.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to `synapse.http.site`.

View file

@ -21,7 +21,7 @@ from zope.interface import implementer
from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure
from twisted.web.resource import IResource
from twisted.web.resource import IResource, Resource
from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig
@ -61,7 +61,7 @@ class SynapseRequest(Request):
logcontext: the log context for this request
"""
def __init__(self, channel, *args, max_request_body_size=1024, **kw):
def __init__(self, channel, *args, max_request_body_size: int = 1024, **kw):
Request.__init__(self, channel, *args, **kw)
self._max_request_body_size = max_request_body_size
self.site: SynapseSite = channel.site
@ -83,13 +83,13 @@ class SynapseRequest(Request):
self._is_processing = False
# the time when the asynchronous request handler completed its processing
self._processing_finished_time = None
self._processing_finished_time: Optional[float] = None
# what time we finished sending the response to the client (or the connection
# dropped)
self.finish_time = None
self.finish_time: Optional[float] = None
def __repr__(self):
def __repr__(self) -> str:
# We overwrite this so that we don't log ``access_token``
return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
self.__class__.__name__,
@ -100,7 +100,7 @@ class SynapseRequest(Request):
self.site.site_tag,
)
def handleContentChunk(self, data):
def handleContentChunk(self, data: bytes) -> None:
# we should have a `content` by now.
assert self.content, "handleContentChunk() called before gotLength()"
if self.content.tell() + len(data) > self._max_request_body_size:
@ -139,7 +139,7 @@ class SynapseRequest(Request):
# If there's no authenticated entity, it was the requester.
self.logcontext.request.authenticated_entity = authenticated_entity or requester
def get_request_id(self):
def get_request_id(self) -> str:
return "%s-%i" % (self.get_method(), self.request_seq)
def get_redacted_uri(self) -> str:
@ -205,7 +205,7 @@ class SynapseRequest(Request):
return None, None
def render(self, resrc):
def render(self, resrc: Resource) -> None:
# this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource.
@ -282,7 +282,7 @@ class SynapseRequest(Request):
if self.finish_time is not None:
self._finished_processing()
def finish(self):
def finish(self) -> None:
"""Called when all response data has been written to this Request.
Overrides twisted.web.server.Request.finish to record the finish time and do
@ -295,7 +295,7 @@ class SynapseRequest(Request):
with PreserveLoggingContext(self.logcontext):
self._finished_processing()
def connectionLost(self, reason):
def connectionLost(self, reason: Union[Failure, Exception]) -> None:
"""Called when the client connection is closed before the response is written.
Overrides twisted.web.server.Request.connectionLost to record the finish time and
@ -327,7 +327,7 @@ class SynapseRequest(Request):
if not self._is_processing:
self._finished_processing()
def _started_processing(self, servlet_name):
def _started_processing(self, servlet_name: str) -> None:
"""Record the fact that we are processing this request.
This will log the request's arrival. Once the request completes,
@ -354,9 +354,11 @@ class SynapseRequest(Request):
self.get_redacted_uri(),
)
def _finished_processing(self):
def _finished_processing(self) -> None:
"""Log the completion of this request and update the metrics"""
assert self.logcontext is not None
assert self.finish_time is not None
usage = self.logcontext.get_resource_usage()
if self._processing_finished_time is None:
@ -437,7 +439,7 @@ class XForwardedForRequest(SynapseRequest):
_forwarded_for: "Optional[_XForwardedForAddress]" = None
_forwarded_https: bool = False
def requestReceived(self, command, path, version):
def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None:
# this method is called by the Channel once the full request has been
# received, to dispatch the request to a resource.
# We can use it to set the IP address and protocol according to the
@ -445,7 +447,7 @@ class XForwardedForRequest(SynapseRequest):
self._process_forwarded_headers()
return super().requestReceived(command, path, version)
def _process_forwarded_headers(self):
def _process_forwarded_headers(self) -> None:
headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
if not headers:
return
@ -470,7 +472,7 @@ class XForwardedForRequest(SynapseRequest):
)
self._forwarded_https = True
def isSecure(self):
def isSecure(self) -> bool:
if self._forwarded_https:
return True
return super().isSecure()
@ -545,14 +547,16 @@ class SynapseSite(Site):
proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest
def request_factory(channel, queued) -> Request:
def request_factory(channel, queued: bool) -> Request:
return request_class(
channel, max_request_body_size=max_request_body_size, queued=queued
channel,
max_request_body_size=max_request_body_size,
queued=queued,
)
self.requestFactory = request_factory # type: ignore
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")
def log(self, request):
def log(self, request: SynapseRequest) -> None:
pass