forked from MirrorHub/synapse
Support for routing outbound HTTP requests via a proxy (#6239)
The `http_proxy` and `HTTPS_PROXY` env vars can be set to a `host[:port]` value which should point to a proxy. The address of the proxy should be excluded from IP blacklists such as the `url_preview_ip_range_blacklist`. The proxy will then be used for * push * url previews * phone-home stats * recaptcha validation * CAS auth validation It will *not* be used for: * Application Services * Identity servers * Outbound federation * In worker configurations, connections from workers to masters Fixes #4198.
This commit is contained in:
parent
fe1f2b4520
commit
1cb84c6486
16 changed files with 812 additions and 12 deletions
1
changelog.d/6238.feature
Normal file
1
changelog.d/6238.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add support for outbound http proxying via http_proxy/HTTPS_PROXY env vars.
|
|
@ -565,7 +565,7 @@ def run(hs):
|
||||||
"Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)
|
"Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
yield hs.get_simple_http_client().put_json(
|
yield hs.get_proxied_http_client().put_json(
|
||||||
hs.config.report_stats_endpoint, stats
|
hs.config.report_stats_endpoint, stats
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -81,7 +81,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self._enabled = bool(hs.config.recaptcha_private_key)
|
self._enabled = bool(hs.config.recaptcha_private_key)
|
||||||
self._http_client = hs.get_simple_http_client()
|
self._http_client = hs.get_proxied_http_client()
|
||||||
self._url = hs.config.recaptcha_siteverify_api
|
self._url = hs.config.recaptcha_siteverify_api
|
||||||
self._secret = hs.config.recaptcha_private_key
|
self._secret = hs.config.recaptcha_private_key
|
||||||
|
|
||||||
|
|
|
@ -45,6 +45,7 @@ from synapse.http import (
|
||||||
cancelled_to_request_timed_out_error,
|
cancelled_to_request_timed_out_error,
|
||||||
redact_uri,
|
redact_uri,
|
||||||
)
|
)
|
||||||
|
from synapse.http.proxyagent import ProxyAgent
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
||||||
from synapse.util.async_helpers import timeout_deferred
|
from synapse.util.async_helpers import timeout_deferred
|
||||||
|
@ -183,7 +184,15 @@ class SimpleHttpClient(object):
|
||||||
using HTTP in Matrix
|
using HTTP in Matrix
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hs,
|
||||||
|
treq_args={},
|
||||||
|
ip_whitelist=None,
|
||||||
|
ip_blacklist=None,
|
||||||
|
http_proxy=None,
|
||||||
|
https_proxy=None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hs (synapse.server.HomeServer)
|
hs (synapse.server.HomeServer)
|
||||||
|
@ -192,6 +201,8 @@ class SimpleHttpClient(object):
|
||||||
we may not request.
|
we may not request.
|
||||||
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
|
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
|
||||||
request if it were otherwise caught in a blacklist.
|
request if it were otherwise caught in a blacklist.
|
||||||
|
http_proxy (bytes): proxy server to use for http connections. host[:port]
|
||||||
|
https_proxy (bytes): proxy server to use for https connections. host[:port]
|
||||||
"""
|
"""
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
|
@ -236,11 +247,13 @@ class SimpleHttpClient(object):
|
||||||
# The default context factory in Twisted 14.0.0 (which we require) is
|
# The default context factory in Twisted 14.0.0 (which we require) is
|
||||||
# BrowserLikePolicyForHTTPS which will do regular cert validation
|
# BrowserLikePolicyForHTTPS which will do regular cert validation
|
||||||
# 'like a browser'
|
# 'like a browser'
|
||||||
self.agent = Agent(
|
self.agent = ProxyAgent(
|
||||||
self.reactor,
|
self.reactor,
|
||||||
connectTimeout=15,
|
connectTimeout=15,
|
||||||
contextFactory=self.hs.get_http_client_context_factory(),
|
contextFactory=self.hs.get_http_client_context_factory(),
|
||||||
pool=pool,
|
pool=pool,
|
||||||
|
http_proxy=http_proxy,
|
||||||
|
https_proxy=https_proxy,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._ip_blacklist:
|
if self._ip_blacklist:
|
||||||
|
|
195
synapse/http/connectproxyclient.py
Normal file
195
synapse/http/connectproxyclient.py
Normal file
|
@ -0,0 +1,195 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
from twisted.internet import defer, protocol
|
||||||
|
from twisted.internet.error import ConnectError
|
||||||
|
from twisted.internet.interfaces import IStreamClientEndpoint
|
||||||
|
from twisted.internet.protocol import connectionDone
|
||||||
|
from twisted.web import http
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyConnectError(ConnectError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@implementer(IStreamClientEndpoint)
|
||||||
|
class HTTPConnectProxyEndpoint(object):
|
||||||
|
"""An Endpoint implementation which will send a CONNECT request to an http proxy
|
||||||
|
|
||||||
|
Wraps an existing HostnameEndpoint for the proxy.
|
||||||
|
|
||||||
|
When we get the connect() request from the connection pool (via the TLS wrapper),
|
||||||
|
we'll first connect to the proxy endpoint with a ProtocolFactory which will make the
|
||||||
|
CONNECT request. Once that completes, we invoke the protocolFactory which was passed
|
||||||
|
in.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reactor: the Twisted reactor to use for the connection
|
||||||
|
proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
|
||||||
|
proxy
|
||||||
|
host (bytes): hostname that we want to CONNECT to
|
||||||
|
port (int): port that we want to connect to
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, reactor, proxy_endpoint, host, port):
|
||||||
|
self._reactor = reactor
|
||||||
|
self._proxy_endpoint = proxy_endpoint
|
||||||
|
self._host = host
|
||||||
|
self._port = port
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
|
||||||
|
|
||||||
|
def connect(self, protocolFactory):
|
||||||
|
f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
|
||||||
|
d = self._proxy_endpoint.connect(f)
|
||||||
|
# once the tcp socket connects successfully, we need to wait for the
|
||||||
|
# CONNECT to complete.
|
||||||
|
d.addCallback(lambda conn: f.on_connection)
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPProxiedClientFactory(protocol.ClientFactory):
|
||||||
|
"""ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect.
|
||||||
|
|
||||||
|
Once the CONNECT completes, invokes the original ClientFactory to build the
|
||||||
|
HTTP Protocol object and run the rest of the connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dst_host (bytes): hostname that we want to CONNECT to
|
||||||
|
dst_port (int): port that we want to connect to
|
||||||
|
wrapped_factory (protocol.ClientFactory): The original Factory
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dst_host, dst_port, wrapped_factory):
|
||||||
|
self.dst_host = dst_host
|
||||||
|
self.dst_port = dst_port
|
||||||
|
self.wrapped_factory = wrapped_factory
|
||||||
|
self.on_connection = defer.Deferred()
|
||||||
|
|
||||||
|
def startedConnecting(self, connector):
|
||||||
|
return self.wrapped_factory.startedConnecting(connector)
|
||||||
|
|
||||||
|
def buildProtocol(self, addr):
|
||||||
|
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
|
||||||
|
|
||||||
|
return HTTPConnectProtocol(
|
||||||
|
self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
|
||||||
|
)
|
||||||
|
|
||||||
|
def clientConnectionFailed(self, connector, reason):
|
||||||
|
logger.debug("Connection to proxy failed: %s", reason)
|
||||||
|
if not self.on_connection.called:
|
||||||
|
self.on_connection.errback(reason)
|
||||||
|
return self.wrapped_factory.clientConnectionFailed(connector, reason)
|
||||||
|
|
||||||
|
def clientConnectionLost(self, connector, reason):
|
||||||
|
logger.debug("Connection to proxy lost: %s", reason)
|
||||||
|
if not self.on_connection.called:
|
||||||
|
self.on_connection.errback(reason)
|
||||||
|
return self.wrapped_factory.clientConnectionLost(connector, reason)
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPConnectProtocol(protocol.Protocol):
|
||||||
|
"""Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
|
||||||
|
to put in the CONNECT request
|
||||||
|
|
||||||
|
port (int): The original HTTP(s) port to put in the CONNECT request
|
||||||
|
|
||||||
|
wrapped_protocol (interfaces.IProtocol): the original protocol (probably
|
||||||
|
HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
|
||||||
|
|
||||||
|
connected_deferred (Deferred): a Deferred which will be callbacked with
|
||||||
|
wrapped_protocol when the CONNECT completes
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, host, port, wrapped_protocol, connected_deferred):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.wrapped_protocol = wrapped_protocol
|
||||||
|
self.connected_deferred = connected_deferred
|
||||||
|
self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
|
||||||
|
self.http_setup_client.on_connected.addCallback(self.proxyConnected)
|
||||||
|
|
||||||
|
def connectionMade(self):
|
||||||
|
self.http_setup_client.makeConnection(self.transport)
|
||||||
|
|
||||||
|
def connectionLost(self, reason=connectionDone):
|
||||||
|
if self.wrapped_protocol.connected:
|
||||||
|
self.wrapped_protocol.connectionLost(reason)
|
||||||
|
|
||||||
|
self.http_setup_client.connectionLost(reason)
|
||||||
|
|
||||||
|
if not self.connected_deferred.called:
|
||||||
|
self.connected_deferred.errback(reason)
|
||||||
|
|
||||||
|
def proxyConnected(self, _):
|
||||||
|
self.wrapped_protocol.makeConnection(self.transport)
|
||||||
|
|
||||||
|
self.connected_deferred.callback(self.wrapped_protocol)
|
||||||
|
|
||||||
|
# Get any pending data from the http buf and forward it to the original protocol
|
||||||
|
buf = self.http_setup_client.clearLineBuffer()
|
||||||
|
if buf:
|
||||||
|
self.wrapped_protocol.dataReceived(buf)
|
||||||
|
|
||||||
|
def dataReceived(self, data):
|
||||||
|
# if we've set up the HTTP protocol, we can send the data there
|
||||||
|
if self.wrapped_protocol.connected:
|
||||||
|
return self.wrapped_protocol.dataReceived(data)
|
||||||
|
|
||||||
|
# otherwise, we must still be setting up the connection: send the data to the
|
||||||
|
# setup client
|
||||||
|
return self.http_setup_client.dataReceived(data)
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPConnectSetupClient(http.HTTPClient):
|
||||||
|
"""HTTPClient protocol to send a CONNECT message for proxies and read the response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host (bytes): The hostname to send in the CONNECT message
|
||||||
|
port (int): The port to send in the CONNECT message
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, host, port):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.on_connected = defer.Deferred()
|
||||||
|
|
||||||
|
def connectionMade(self):
|
||||||
|
logger.debug("Connected to proxy, sending CONNECT")
|
||||||
|
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
|
||||||
|
self.endHeaders()
|
||||||
|
|
||||||
|
def handleStatus(self, version, status, message):
|
||||||
|
logger.debug("Got Status: %s %s %s", status, message, version)
|
||||||
|
if status != b"200":
|
||||||
|
raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
|
||||||
|
|
||||||
|
def handleEndHeaders(self):
|
||||||
|
logger.debug("End Headers")
|
||||||
|
self.on_connected.callback(None)
|
||||||
|
|
||||||
|
def handleResponse(self, body):
|
||||||
|
pass
|
195
synapse/http/proxyagent.py
Normal file
195
synapse/http/proxyagent.py
Normal file
|
@ -0,0 +1,195 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
|
||||||
|
from twisted.web.error import SchemeNotSupported
|
||||||
|
from twisted.web.iweb import IAgent
|
||||||
|
|
||||||
|
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
|
||||||
|
|
||||||
|
|
||||||
|
@implementer(IAgent)
|
||||||
|
class ProxyAgent(_AgentBase):
|
||||||
|
"""An Agent implementation which will use an HTTP proxy if one was requested
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reactor: twisted reactor to place outgoing
|
||||||
|
connections.
|
||||||
|
|
||||||
|
contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
|
||||||
|
verification parameters of OpenSSL. The default is to use a
|
||||||
|
`BrowserLikePolicyForHTTPS`, so unless you have special
|
||||||
|
requirements you can leave this as-is.
|
||||||
|
|
||||||
|
connectTimeout (float): The amount of time that this Agent will wait
|
||||||
|
for the peer to accept a connection.
|
||||||
|
|
||||||
|
bindAddress (bytes): The local address for client sockets to bind to.
|
||||||
|
|
||||||
|
pool (HTTPConnectionPool|None): connection pool to be used. If None, a
|
||||||
|
non-persistent pool instance will be created.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
reactor,
|
||||||
|
contextFactory=BrowserLikePolicyForHTTPS(),
|
||||||
|
connectTimeout=None,
|
||||||
|
bindAddress=None,
|
||||||
|
pool=None,
|
||||||
|
http_proxy=None,
|
||||||
|
https_proxy=None,
|
||||||
|
):
|
||||||
|
_AgentBase.__init__(self, reactor, pool)
|
||||||
|
|
||||||
|
self._endpoint_kwargs = {}
|
||||||
|
if connectTimeout is not None:
|
||||||
|
self._endpoint_kwargs["timeout"] = connectTimeout
|
||||||
|
if bindAddress is not None:
|
||||||
|
self._endpoint_kwargs["bindAddress"] = bindAddress
|
||||||
|
|
||||||
|
self.http_proxy_endpoint = _http_proxy_endpoint(
|
||||||
|
http_proxy, reactor, **self._endpoint_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.https_proxy_endpoint = _http_proxy_endpoint(
|
||||||
|
https_proxy, reactor, **self._endpoint_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self._policy_for_https = contextFactory
|
||||||
|
self._reactor = reactor
|
||||||
|
|
||||||
|
def request(self, method, uri, headers=None, bodyProducer=None):
|
||||||
|
"""
|
||||||
|
Issue a request to the server indicated by the given uri.
|
||||||
|
|
||||||
|
Supports `http` and `https` schemes.
|
||||||
|
|
||||||
|
An existing connection from the connection pool may be used or a new one may be
|
||||||
|
created.
|
||||||
|
|
||||||
|
See also: twisted.web.iweb.IAgent.request
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method (bytes): The request method to use, such as `GET`, `POST`, etc
|
||||||
|
|
||||||
|
uri (bytes): The location of the resource to request.
|
||||||
|
|
||||||
|
headers (Headers|None): Extra headers to send with the request
|
||||||
|
|
||||||
|
bodyProducer (IBodyProducer|None): An object which can generate bytes to
|
||||||
|
make up the body of this request (for example, the properly encoded
|
||||||
|
contents of a file for a file upload). Or, None if the request is to
|
||||||
|
have no body.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[IResponse]: completes when the header of the response has
|
||||||
|
been received (regardless of the response status code).
|
||||||
|
"""
|
||||||
|
uri = uri.strip()
|
||||||
|
if not _VALID_URI.match(uri):
|
||||||
|
raise ValueError("Invalid URI {!r}".format(uri))
|
||||||
|
|
||||||
|
parsed_uri = URI.fromBytes(uri)
|
||||||
|
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
|
||||||
|
request_path = parsed_uri.originForm
|
||||||
|
|
||||||
|
if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
|
||||||
|
# Cache *all* connections under the same key, since we are only
|
||||||
|
# connecting to a single destination, the proxy:
|
||||||
|
pool_key = ("http-proxy", self.http_proxy_endpoint)
|
||||||
|
endpoint = self.http_proxy_endpoint
|
||||||
|
request_path = uri
|
||||||
|
elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
|
||||||
|
endpoint = HTTPConnectProxyEndpoint(
|
||||||
|
self._reactor,
|
||||||
|
self.https_proxy_endpoint,
|
||||||
|
parsed_uri.host,
|
||||||
|
parsed_uri.port,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# not using a proxy
|
||||||
|
endpoint = HostnameEndpoint(
|
||||||
|
self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Requesting %s via %s", uri, endpoint)
|
||||||
|
|
||||||
|
if parsed_uri.scheme == b"https":
|
||||||
|
tls_connection_creator = self._policy_for_https.creatorForNetloc(
|
||||||
|
parsed_uri.host, parsed_uri.port
|
||||||
|
)
|
||||||
|
endpoint = wrapClientTLS(tls_connection_creator, endpoint)
|
||||||
|
elif parsed_uri.scheme == b"http":
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
return defer.fail(
|
||||||
|
Failure(
|
||||||
|
SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._requestWithEndpoint(
|
||||||
|
pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _http_proxy_endpoint(proxy, reactor, **kwargs):
|
||||||
|
"""Parses an http proxy setting and returns an endpoint for the proxy
|
||||||
|
|
||||||
|
Args:
|
||||||
|
proxy (bytes|None): the proxy setting
|
||||||
|
reactor: reactor to be used to connect to the proxy
|
||||||
|
kwargs: other args to be passed to HostnameEndpoint
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy,
|
||||||
|
or None
|
||||||
|
"""
|
||||||
|
if proxy is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# currently we only support hostname:port. Some apps also support
|
||||||
|
# protocol://<host>[:port], which allows a way of requiring a TLS connection to the
|
||||||
|
# proxy.
|
||||||
|
|
||||||
|
host, port = parse_host_port(proxy, default_port=1080)
|
||||||
|
return HostnameEndpoint(reactor, host, port, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_host_port(hostport, default_port=None):
|
||||||
|
# could have sworn we had one of these somewhere else...
|
||||||
|
if b":" in hostport:
|
||||||
|
host, port = hostport.rsplit(b":", 1)
|
||||||
|
try:
|
||||||
|
port = int(port)
|
||||||
|
return host, port
|
||||||
|
except ValueError:
|
||||||
|
# the thing after the : wasn't a valid port; presumably this is an
|
||||||
|
# IPv6 address.
|
||||||
|
pass
|
||||||
|
|
||||||
|
return hostport, default_port
|
|
@ -103,7 +103,7 @@ class HttpPusher(object):
|
||||||
if "url" not in self.data:
|
if "url" not in self.data:
|
||||||
raise PusherConfigException("'url' required in data for HTTP pusher")
|
raise PusherConfigException("'url' required in data for HTTP pusher")
|
||||||
self.url = self.data["url"]
|
self.url = self.data["url"]
|
||||||
self.http_client = hs.get_simple_http_client()
|
self.http_client = hs.get_proxied_http_client()
|
||||||
self.data_minus_url = {}
|
self.data_minus_url = {}
|
||||||
self.data_minus_url.update(self.data)
|
self.data_minus_url.update(self.data)
|
||||||
del self.data_minus_url["url"]
|
del self.data_minus_url["url"]
|
||||||
|
|
|
@ -381,7 +381,7 @@ class CasTicketServlet(RestServlet):
|
||||||
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
|
self.cas_displayname_attribute = hs.config.cas_displayname_attribute
|
||||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||||
self._sso_auth_handler = SSOAuthHandler(hs)
|
self._sso_auth_handler = SSOAuthHandler(hs)
|
||||||
self._http_client = hs.get_simple_http_client()
|
self._http_client = hs.get_proxied_http_client()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
|
|
|
@ -74,6 +74,8 @@ class PreviewUrlResource(DirectServeResource):
|
||||||
treq_args={"browser_like_redirects": True},
|
treq_args={"browser_like_redirects": True},
|
||||||
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
|
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
|
||||||
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
|
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
|
||||||
|
http_proxy=os.getenv("http_proxy"),
|
||||||
|
https_proxy=os.getenv("HTTPS_PROXY"),
|
||||||
)
|
)
|
||||||
self.media_repo = media_repo
|
self.media_repo = media_repo
|
||||||
self.primary_base_path = media_repo.primary_base_path
|
self.primary_base_path = media_repo.primary_base_path
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
# Imports required for the default HomeServer() implementation
|
# Imports required for the default HomeServer() implementation
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
from twisted.enterprise import adbapi
|
from twisted.enterprise import adbapi
|
||||||
from twisted.mail.smtp import sendmail
|
from twisted.mail.smtp import sendmail
|
||||||
|
@ -168,6 +169,7 @@ class HomeServer(object):
|
||||||
"filtering",
|
"filtering",
|
||||||
"http_client_context_factory",
|
"http_client_context_factory",
|
||||||
"simple_http_client",
|
"simple_http_client",
|
||||||
|
"proxied_http_client",
|
||||||
"media_repository",
|
"media_repository",
|
||||||
"media_repository_resource",
|
"media_repository_resource",
|
||||||
"federation_transport_client",
|
"federation_transport_client",
|
||||||
|
@ -311,6 +313,13 @@ class HomeServer(object):
|
||||||
def build_simple_http_client(self):
|
def build_simple_http_client(self):
|
||||||
return SimpleHttpClient(self)
|
return SimpleHttpClient(self)
|
||||||
|
|
||||||
|
def build_proxied_http_client(self):
|
||||||
|
return SimpleHttpClient(
|
||||||
|
self,
|
||||||
|
http_proxy=os.getenv("http_proxy"),
|
||||||
|
https_proxy=os.getenv("HTTPS_PROXY"),
|
||||||
|
)
|
||||||
|
|
||||||
def build_room_creation_handler(self):
|
def build_room_creation_handler(self):
|
||||||
return RoomCreationHandler(self)
|
return RoomCreationHandler(self)
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ import synapse.handlers.message
|
||||||
import synapse.handlers.room
|
import synapse.handlers.room
|
||||||
import synapse.handlers.room_member
|
import synapse.handlers.room_member
|
||||||
import synapse.handlers.set_password
|
import synapse.handlers.set_password
|
||||||
|
import synapse.http.client
|
||||||
import synapse.rest.media.v1.media_repository
|
import synapse.rest.media.v1.media_repository
|
||||||
import synapse.server_notices.server_notices_manager
|
import synapse.server_notices.server_notices_manager
|
||||||
import synapse.server_notices.server_notices_sender
|
import synapse.server_notices.server_notices_sender
|
||||||
|
@ -38,6 +39,14 @@ class HomeServer(object):
|
||||||
pass
|
pass
|
||||||
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
|
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
|
||||||
pass
|
pass
|
||||||
|
def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient:
|
||||||
|
"""Fetch an HTTP client implementation which doesn't do any blacklisting
|
||||||
|
or support any HTTP_PROXY settings"""
|
||||||
|
pass
|
||||||
|
def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient:
|
||||||
|
"""Fetch an HTTP client implementation which doesn't do any blacklisting
|
||||||
|
but does support HTTP_PROXY settings"""
|
||||||
|
pass
|
||||||
def get_deactivate_account_handler(
|
def get_deactivate_account_handler(
|
||||||
self,
|
self,
|
||||||
) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
|
) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
|
||||||
|
|
|
@ -20,6 +20,23 @@ from zope.interface import implementer
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
from OpenSSL.SSL import Connection
|
from OpenSSL.SSL import Connection
|
||||||
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
|
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
|
||||||
|
from twisted.internet.ssl import Certificate, trustRootFromCertificates
|
||||||
|
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
|
||||||
|
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_https_policy():
|
||||||
|
"""Get a test IPolicyForHTTPS which trusts the test CA cert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
IPolicyForHTTPS
|
||||||
|
"""
|
||||||
|
ca_file = get_test_ca_cert_file()
|
||||||
|
with open(ca_file) as stream:
|
||||||
|
content = stream.read()
|
||||||
|
cert = Certificate.loadPEM(content)
|
||||||
|
trust_root = trustRootFromCertificates([cert])
|
||||||
|
return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
|
||||||
|
|
||||||
|
|
||||||
def get_test_ca_cert_file():
|
def get_test_ca_cert_file():
|
||||||
|
|
|
@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||||
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
|
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# grab a hold of the TLS connection, in case it gets torn down
|
||||||
|
server_tls_connection = server_tls_protocol._tlsConnection
|
||||||
|
|
||||||
|
# fish the test server back out of the server-side TLS protocol.
|
||||||
|
http_protocol = server_tls_protocol.wrappedProtocol
|
||||||
|
|
||||||
# give the reactor a pump to get the TLS juices flowing.
|
# give the reactor a pump to get the TLS juices flowing.
|
||||||
self.reactor.pump((0.1,))
|
self.reactor.pump((0.1,))
|
||||||
|
|
||||||
# check the SNI
|
# check the SNI
|
||||||
server_name = server_tls_protocol._tlsConnection.get_servername()
|
server_name = server_tls_connection.get_servername()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
server_name,
|
server_name,
|
||||||
expected_sni,
|
expected_sni,
|
||||||
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
# fish the test server back out of the server-side TLS protocol.
|
return http_protocol
|
||||||
return server_tls_protocol.wrappedProtocol
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _make_get_request(self, uri):
|
def _make_get_request(self, uri):
|
||||||
|
|
334
tests/http/test_proxyagent.py
Normal file
334
tests/http/test_proxyagent.py
Normal file
|
@ -0,0 +1,334 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import treq
|
||||||
|
|
||||||
|
from twisted.internet import interfaces # noqa: F401
|
||||||
|
from twisted.internet.protocol import Factory
|
||||||
|
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||||
|
from twisted.web.http import HTTPChannel
|
||||||
|
|
||||||
|
from synapse.http.proxyagent import ProxyAgent
|
||||||
|
|
||||||
|
from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
|
||||||
|
from tests.server import FakeTransport, ThreadedMemoryReactorClock
|
||||||
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
HTTPFactory = Factory.forProtocol(HTTPChannel)
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixFederationAgentTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.reactor = ThreadedMemoryReactorClock()
|
||||||
|
|
||||||
|
def _make_connection(
|
||||||
|
self, client_factory, server_factory, ssl=False, expected_sni=None
|
||||||
|
):
|
||||||
|
"""Builds a test server, and completes the outgoing client connection
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_factory (interfaces.IProtocolFactory): the the factory that the
|
||||||
|
application is trying to use to make the outbound connection. We will
|
||||||
|
invoke it to build the client Protocol
|
||||||
|
|
||||||
|
server_factory (interfaces.IProtocolFactory): a factory to build the
|
||||||
|
server-side protocol
|
||||||
|
|
||||||
|
ssl (bool): If true, we will expect an ssl connection and wrap
|
||||||
|
server_factory with a TLSMemoryBIOFactory
|
||||||
|
|
||||||
|
expected_sni (bytes|None): the expected SNI value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
IProtocol: the server Protocol returned by server_factory
|
||||||
|
"""
|
||||||
|
if ssl:
|
||||||
|
server_factory = _wrap_server_factory_for_tls(server_factory)
|
||||||
|
|
||||||
|
server_protocol = server_factory.buildProtocol(None)
|
||||||
|
|
||||||
|
# now, tell the client protocol factory to build the client protocol,
|
||||||
|
# and wire the output of said protocol up to the server via
|
||||||
|
# a FakeTransport.
|
||||||
|
#
|
||||||
|
# Normally this would be done by the TCP socket code in Twisted, but we are
|
||||||
|
# stubbing that out here.
|
||||||
|
client_protocol = client_factory.buildProtocol(None)
|
||||||
|
client_protocol.makeConnection(
|
||||||
|
FakeTransport(server_protocol, self.reactor, client_protocol)
|
||||||
|
)
|
||||||
|
|
||||||
|
# tell the server protocol to send its stuff back to the client, too
|
||||||
|
server_protocol.makeConnection(
|
||||||
|
FakeTransport(client_protocol, self.reactor, server_protocol)
|
||||||
|
)
|
||||||
|
|
||||||
|
if ssl:
|
||||||
|
http_protocol = server_protocol.wrappedProtocol
|
||||||
|
tls_connection = server_protocol._tlsConnection
|
||||||
|
else:
|
||||||
|
http_protocol = server_protocol
|
||||||
|
tls_connection = None
|
||||||
|
|
||||||
|
# give the reactor a pump to get the TLS juices flowing (if needed)
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
if expected_sni is not None:
|
||||||
|
server_name = tls_connection.get_servername()
|
||||||
|
self.assertEqual(
|
||||||
|
server_name,
|
||||||
|
expected_sni,
|
||||||
|
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
return http_protocol
|
||||||
|
|
||||||
|
def test_http_request(self):
|
||||||
|
agent = ProxyAgent(self.reactor)
|
||||||
|
|
||||||
|
self.reactor.lookups["test.com"] = "1.2.3.4"
|
||||||
|
d = agent.request(b"GET", b"http://test.com")
|
||||||
|
|
||||||
|
# there should be a pending TCP connection
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, "1.2.3.4")
|
||||||
|
self.assertEqual(port, 80)
|
||||||
|
|
||||||
|
# make a test server, and wire up the client
|
||||||
|
http_server = self._make_connection(
|
||||||
|
client_factory, _get_test_protocol_factory()
|
||||||
|
)
|
||||||
|
|
||||||
|
# the FakeTransport is async, so we need to pump the reactor
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
# now there should be a pending request
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
|
||||||
|
request = http_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b"GET")
|
||||||
|
self.assertEqual(request.path, b"/")
|
||||||
|
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||||
|
request.write(b"result")
|
||||||
|
request.finish()
|
||||||
|
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
resp = self.successResultOf(d)
|
||||||
|
body = self.successResultOf(treq.content(resp))
|
||||||
|
self.assertEqual(body, b"result")
|
||||||
|
|
||||||
|
def test_https_request(self):
|
||||||
|
agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
|
||||||
|
|
||||||
|
self.reactor.lookups["test.com"] = "1.2.3.4"
|
||||||
|
d = agent.request(b"GET", b"https://test.com/abc")
|
||||||
|
|
||||||
|
# there should be a pending TCP connection
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, "1.2.3.4")
|
||||||
|
self.assertEqual(port, 443)
|
||||||
|
|
||||||
|
# make a test server, and wire up the client
|
||||||
|
http_server = self._make_connection(
|
||||||
|
client_factory,
|
||||||
|
_get_test_protocol_factory(),
|
||||||
|
ssl=True,
|
||||||
|
expected_sni=b"test.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
# the FakeTransport is async, so we need to pump the reactor
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
# now there should be a pending request
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
|
||||||
|
request = http_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b"GET")
|
||||||
|
self.assertEqual(request.path, b"/abc")
|
||||||
|
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||||
|
request.write(b"result")
|
||||||
|
request.finish()
|
||||||
|
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
resp = self.successResultOf(d)
|
||||||
|
body = self.successResultOf(treq.content(resp))
|
||||||
|
self.assertEqual(body, b"result")
|
||||||
|
|
||||||
|
def test_http_request_via_proxy(self):
|
||||||
|
agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
|
||||||
|
|
||||||
|
self.reactor.lookups["proxy.com"] = "1.2.3.5"
|
||||||
|
d = agent.request(b"GET", b"http://test.com")
|
||||||
|
|
||||||
|
# there should be a pending TCP connection
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, "1.2.3.5")
|
||||||
|
self.assertEqual(port, 8888)
|
||||||
|
|
||||||
|
# make a test server, and wire up the client
|
||||||
|
http_server = self._make_connection(
|
||||||
|
client_factory, _get_test_protocol_factory()
|
||||||
|
)
|
||||||
|
|
||||||
|
# the FakeTransport is async, so we need to pump the reactor
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
# now there should be a pending request
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
|
||||||
|
request = http_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b"GET")
|
||||||
|
self.assertEqual(request.path, b"http://test.com")
|
||||||
|
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||||
|
request.write(b"result")
|
||||||
|
request.finish()
|
||||||
|
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
resp = self.successResultOf(d)
|
||||||
|
body = self.successResultOf(treq.content(resp))
|
||||||
|
self.assertEqual(body, b"result")
|
||||||
|
|
||||||
|
def test_https_request_via_proxy(self):
|
||||||
|
agent = ProxyAgent(
|
||||||
|
self.reactor,
|
||||||
|
contextFactory=get_test_https_policy(),
|
||||||
|
https_proxy=b"proxy.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.reactor.lookups["proxy.com"] = "1.2.3.5"
|
||||||
|
d = agent.request(b"GET", b"https://test.com/abc")
|
||||||
|
|
||||||
|
# there should be a pending TCP connection
|
||||||
|
clients = self.reactor.tcpClients
|
||||||
|
self.assertEqual(len(clients), 1)
|
||||||
|
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
|
||||||
|
self.assertEqual(host, "1.2.3.5")
|
||||||
|
self.assertEqual(port, 1080)
|
||||||
|
|
||||||
|
# make a test HTTP server, and wire up the client
|
||||||
|
proxy_server = self._make_connection(
|
||||||
|
client_factory, _get_test_protocol_factory()
|
||||||
|
)
|
||||||
|
|
||||||
|
# fish the transports back out so that we can do the old switcheroo
|
||||||
|
s2c_transport = proxy_server.transport
|
||||||
|
client_protocol = s2c_transport.other
|
||||||
|
c2s_transport = client_protocol.transport
|
||||||
|
|
||||||
|
# the FakeTransport is async, so we need to pump the reactor
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
# now there should be a pending CONNECT request
|
||||||
|
self.assertEqual(len(proxy_server.requests), 1)
|
||||||
|
|
||||||
|
request = proxy_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b"CONNECT")
|
||||||
|
self.assertEqual(request.path, b"test.com:443")
|
||||||
|
|
||||||
|
# tell the proxy server not to close the connection
|
||||||
|
proxy_server.persistent = True
|
||||||
|
|
||||||
|
# this just stops the http Request trying to do a chunked response
|
||||||
|
# request.setHeader(b"Content-Length", b"0")
|
||||||
|
request.finish()
|
||||||
|
|
||||||
|
# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
|
||||||
|
ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
|
||||||
|
ssl_protocol = ssl_factory.buildProtocol(None)
|
||||||
|
http_server = ssl_protocol.wrappedProtocol
|
||||||
|
|
||||||
|
ssl_protocol.makeConnection(
|
||||||
|
FakeTransport(client_protocol, self.reactor, ssl_protocol)
|
||||||
|
)
|
||||||
|
c2s_transport.other = ssl_protocol
|
||||||
|
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
server_name = ssl_protocol._tlsConnection.get_servername()
|
||||||
|
expected_sni = b"test.com"
|
||||||
|
self.assertEqual(
|
||||||
|
server_name,
|
||||||
|
expected_sni,
|
||||||
|
"Expected SNI %s but got %s" % (expected_sni, server_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
# now there should be a pending request
|
||||||
|
self.assertEqual(len(http_server.requests), 1)
|
||||||
|
|
||||||
|
request = http_server.requests[0]
|
||||||
|
self.assertEqual(request.method, b"GET")
|
||||||
|
self.assertEqual(request.path, b"/abc")
|
||||||
|
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
|
||||||
|
request.write(b"result")
|
||||||
|
request.finish()
|
||||||
|
|
||||||
|
self.reactor.advance(0)
|
||||||
|
|
||||||
|
resp = self.successResultOf(d)
|
||||||
|
body = self.successResultOf(treq.content(resp))
|
||||||
|
self.assertEqual(body, b"result")
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap_server_factory_for_tls(factory, sanlist=None):
|
||||||
|
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
|
||||||
|
|
||||||
|
The resultant factory will create a TLS server which presents a certificate
|
||||||
|
signed by our test CA, valid for the domains in `sanlist`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factory (interfaces.IProtocolFactory): protocol factory to wrap
|
||||||
|
sanlist (iterable[bytes]): list of domains the cert should be valid for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
interfaces.IProtocolFactory
|
||||||
|
"""
|
||||||
|
if sanlist is None:
|
||||||
|
sanlist = [b"DNS:test.com"]
|
||||||
|
|
||||||
|
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
|
||||||
|
return TLSMemoryBIOFactory(
|
||||||
|
connection_creator, isClient=False, wrappedFactory=factory
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_test_protocol_factory():
|
||||||
|
"""Get a protocol Factory which will build an HTTPChannel
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
interfaces.IProtocolFactory
|
||||||
|
"""
|
||||||
|
server_factory = Factory.forProtocol(HTTPChannel)
|
||||||
|
|
||||||
|
# Request.finish expects the factory to have a 'log' method.
|
||||||
|
server_factory.log = _log_request
|
||||||
|
|
||||||
|
return server_factory
|
||||||
|
|
||||||
|
|
||||||
|
def _log_request(request):
|
||||||
|
"""Implements Factory.log, which is expected by Request.finish"""
|
||||||
|
logger.info("Completed request %s", request)
|
|
@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
config = self.default_config()
|
config = self.default_config()
|
||||||
config["start_pushers"] = True
|
config["start_pushers"] = True
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(config=config, simple_http_client=m)
|
hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
|
|
|
@ -395,11 +395,24 @@ class FakeTransport(object):
|
||||||
self.disconnecting = True
|
self.disconnecting = True
|
||||||
if self._protocol:
|
if self._protocol:
|
||||||
self._protocol.connectionLost(reason)
|
self._protocol.connectionLost(reason)
|
||||||
|
|
||||||
|
# if we still have data to write, delay until that is done
|
||||||
|
if self.buffer:
|
||||||
|
logger.info(
|
||||||
|
"FakeTransport: Delaying disconnect until buffer is flushed"
|
||||||
|
)
|
||||||
|
else:
|
||||||
self.disconnected = True
|
self.disconnected = True
|
||||||
|
|
||||||
def abortConnection(self):
|
def abortConnection(self):
|
||||||
logger.info("FakeTransport: abortConnection()")
|
logger.info("FakeTransport: abortConnection()")
|
||||||
self.loseConnection()
|
|
||||||
|
if not self.disconnecting:
|
||||||
|
self.disconnecting = True
|
||||||
|
if self._protocol:
|
||||||
|
self._protocol.connectionLost(None)
|
||||||
|
|
||||||
|
self.disconnected = True
|
||||||
|
|
||||||
def pauseProducing(self):
|
def pauseProducing(self):
|
||||||
if not self.producer:
|
if not self.producer:
|
||||||
|
@ -430,6 +443,9 @@ class FakeTransport(object):
|
||||||
self._reactor.callLater(0.0, _produce)
|
self._reactor.callLater(0.0, _produce)
|
||||||
|
|
||||||
def write(self, byt):
|
def write(self, byt):
|
||||||
|
if self.disconnecting:
|
||||||
|
raise Exception("Writing to disconnecting FakeTransport")
|
||||||
|
|
||||||
self.buffer = self.buffer + byt
|
self.buffer = self.buffer + byt
|
||||||
|
|
||||||
# always actually do the write asynchronously. Some protocols (notably the
|
# always actually do the write asynchronously. Some protocols (notably the
|
||||||
|
@ -474,6 +490,10 @@ class FakeTransport(object):
|
||||||
if self.buffer and self.autoflush:
|
if self.buffer and self.autoflush:
|
||||||
self._reactor.callLater(0.0, self.flush)
|
self._reactor.callLater(0.0, self.flush)
|
||||||
|
|
||||||
|
if not self.buffer and self.disconnecting:
|
||||||
|
logger.info("FakeTransport: Buffer now empty, completing disconnect")
|
||||||
|
self.disconnected = True
|
||||||
|
|
||||||
|
|
||||||
def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
|
def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in a new issue