Correctly handle x_forwaded listener option

This commit is contained in:
Erik Johnston 2015-06-12 17:13:23 +01:00
parent fd2c07bfed
commit 9c5fc81c2d
3 changed files with 31 additions and 14 deletions

View file

@ -34,7 +34,7 @@ from twisted.application import service
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File from twisted.web.static import File
from twisted.web.server import Site, GzipEncoderFactory from twisted.web.server import Site, GzipEncoderFactory, Request
from twisted.web.http import proxiedLogFormatter, combinedLogFormatter from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
from synapse.http.server import JsonResource, RootRedirect from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
@ -199,7 +199,7 @@ class SynapseHomeServer(HomeServer):
port, port,
SynapseSite( SynapseSite(
"synapse.access.https", "synapse.access.https",
config, listener_config,
root_resource, root_resource,
), ),
self.tls_context_factory, self.tls_context_factory,
@ -210,7 +210,7 @@ class SynapseHomeServer(HomeServer):
port, port,
SynapseSite( SynapseSite(
"synapse.access.https", "synapse.access.https",
config, listener_config,
root_resource, root_resource,
), ),
interface=bind_address interface=bind_address
@ -441,6 +441,28 @@ class SynapseService(service.Service):
return self._port.stopListening() return self._port.stopListening()
class XForwardedForRequest(Request):
def __init__(self, *args, **kw):
Request.__init__(self, *args, **kw)
"""
Add a layer on top of another request that only uses the value of an
X-Forwarded-For header as the result of C{getClientIP}.
"""
def getClientIP(self):
"""
@return: The client address (the first address) in the value of the
I{X-Forwarded-For header}. If the header is not present, return
C{b"-"}.
"""
return self.requestHeaders.getRawHeaders(
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
def XForwardedFactory(*args, **kwargs):
return XForwardedForRequest(*args, **kwargs)
class SynapseSite(Site): class SynapseSite(Site):
""" """
Subclass of a twisted http Site that does access logging with python's Subclass of a twisted http Site that does access logging with python's
@ -448,7 +470,8 @@ class SynapseSite(Site):
""" """
def __init__(self, logger_name, config, resource, *args, **kwargs): def __init__(self, logger_name, config, resource, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs) Site.__init__(self, resource, *args, **kwargs)
if config.captcha_ip_origin_is_x_forwarded: if config.get("x_forwarded", False):
self.requestFactory = XForwardedFactory
self._log_formatter = proxiedLogFormatter self._log_formatter = proxiedLogFormatter
else: else:
self._log_formatter = combinedLogFormatter self._log_formatter = combinedLogFormatter

View file

@ -157,6 +157,8 @@ class ServerConfig(Config):
bind_address: '' bind_address: ''
type: http type: http
x_forwarded: False
resources: resources:
- names: [client, webclient] - names: [client, webclient]
compress: true compress: true

View file

@ -132,16 +132,8 @@ class BaseHomeServer(object):
setattr(BaseHomeServer, "get_%s" % (depname), _get) setattr(BaseHomeServer, "get_%s" % (depname), _get)
def get_ip_from_request(self, request): def get_ip_from_request(self, request):
# May be an X-Forwarding-For header depending on config # X-Forwarded-For is handled by our custom request type.
ip_addr = request.getClientIP() return request.getClientIP()
if self.config.captcha_ip_origin_is_x_forwarded:
# use the header
if request.requestHeaders.hasHeader("X-Forwarded-For"):
ip_addr = request.requestHeaders.getRawHeaders(
"X-Forwarded-For"
)[0]
return ip_addr
def is_mine(self, domain_specific_string): def is_mine(self, domain_specific_string):
return domain_specific_string.domain == self.hostname return domain_specific_string.domain == self.hostname