Use getClientAddress instead of getClientIP. (#12599)

getClientIP was deprecated in Twisted 18.4.0, which also added
getClientAddress. The Synapse minimum version for Twisted is
currently 18.9.0, so all supported versions have the new API.
This commit is contained in:
Patrick Cloke 2022-05-04 14:11:21 -04:00 committed by GitHub
parent 116a4c8340
commit 7fbf42499d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 62 additions and 46 deletions

1
changelog.d/12599.misc Normal file
View file

@ -0,0 +1 @@
Use `getClientAddress` instead of the deprecated `getClientIP`.

View file

@ -187,7 +187,7 @@ class Auth:
Once get_user_by_req has set up the opentracing span, this does the actual work. Once get_user_by_req has set up the opentracing span, this does the actual work.
""" """
try: try:
ip_addr = request.getClientIP() ip_addr = request.getClientAddress().host
user_agent = get_request_user_agent(request) user_agent = get_request_user_agent(request)
access_token = self.get_access_token_from_request(request) access_token = self.get_access_token_from_request(request)
@ -356,7 +356,7 @@ class Auth:
return None, None, None return None, None, None
if app_service.ip_range_whitelist: if app_service.ip_range_whitelist:
ip_address = IPAddress(request.getClientIP()) ip_address = IPAddress(request.getClientAddress().host)
if ip_address not in app_service.ip_range_whitelist: if ip_address not in app_service.ip_range_whitelist:
return None, None, None return None, None, None

View file

@ -551,7 +551,7 @@ class AuthHandler:
await self.store.set_ui_auth_clientdict(sid, clientdict) await self.store.set_ui_auth_clientdict(sid, clientdict)
user_agent = get_request_user_agent(request) user_agent = get_request_user_agent(request)
clientip = request.getClientIP() clientip = request.getClientAddress().host
await self.store.add_user_agent_ip_to_ui_auth_session( await self.store.add_user_agent_ip_to_ui_auth_session(
session.session_id, user_agent, clientip session.session_id, user_agent, clientip

View file

@ -92,7 +92,7 @@ class IdentityHandler:
""" """
await self._3pid_validation_ratelimiter_ip.ratelimit( await self._3pid_validation_ratelimiter_ip.ratelimit(
None, (medium, request.getClientIP()) None, (medium, request.getClientAddress().host)
) )
await self._3pid_validation_ratelimiter_address.ratelimit( await self._3pid_validation_ratelimiter_address.ratelimit(
None, (medium, address) None, (medium, address)

View file

@ -468,7 +468,7 @@ class SsoHandler:
auth_provider_id, auth_provider_id,
remote_user_id, remote_user_id,
get_request_user_agent(request), get_request_user_agent(request),
request.getClientIP(), request.getClientAddress().host,
) )
new_user = True new_user = True
elif self._sso_update_profile_information: elif self._sso_update_profile_information:
@ -928,7 +928,7 @@ class SsoHandler:
session.auth_provider_id, session.auth_provider_id,
session.remote_user_id, session.remote_user_id,
get_request_user_agent(request), get_request_user_agent(request),
request.getClientIP(), request.getClientAddress().host,
) )
logger.info( logger.info(

View file

@ -238,7 +238,7 @@ class SynapseRequest(Request):
request_id, request_id,
request=ContextRequest( request=ContextRequest(
request_id=request_id, request_id=request_id,
ip_address=self.getClientIP(), ip_address=self.getClientAddress().host,
site_tag=self.synapse_site.site_tag, site_tag=self.synapse_site.site_tag,
# The requester is going to be unknown at this point. # The requester is going to be unknown at this point.
requester=None, requester=None,
@ -381,7 +381,7 @@ class SynapseRequest(Request):
self.synapse_site.access_logger.debug( self.synapse_site.access_logger.debug(
"%s - %s - Received request: %s %s", "%s - %s - Received request: %s %s",
self.getClientIP(), self.getClientAddress().host,
self.synapse_site.site_tag, self.synapse_site.site_tag,
self.get_method(), self.get_method(),
self.get_redacted_uri(), self.get_redacted_uri(),
@ -429,7 +429,7 @@ class SynapseRequest(Request):
"%s - %s - {%s}" "%s - %s - {%s}"
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
' %sB %s "%s %s %s" "%s" [%d dbevts]', ' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(), self.getClientAddress().host,
self.synapse_site.site_tag, self.synapse_site.site_tag,
requester, requester,
processing_time, processing_time,

View file

@ -884,7 +884,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
tags.HTTP_METHOD: request.get_method(), tags.HTTP_METHOD: request.get_method(),
tags.HTTP_URL: request.get_redacted_uri(), tags.HTTP_URL: request.get_redacted_uri(),
tags.PEER_HOST_IPV6: request.getClientIP(), tags.PEER_HOST_IPV6: request.getClientAddress().host,
} }
request_name = request.request_metrics.name request_name = request.request_metrics.name

View file

@ -112,7 +112,7 @@ class AuthRestServlet(RestServlet):
try: try:
await self.auth_handler.add_oob_auth( await self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA, authdict, request.getClientIP() LoginType.RECAPTCHA, authdict, request.getClientAddress().host
) )
except LoginError as e: except LoginError as e:
# Authentication failed, let user try again # Authentication failed, let user try again
@ -132,7 +132,7 @@ class AuthRestServlet(RestServlet):
try: try:
await self.auth_handler.add_oob_auth( await self.auth_handler.add_oob_auth(
LoginType.TERMS, authdict, request.getClientIP() LoginType.TERMS, authdict, request.getClientAddress().host
) )
except LoginError as e: except LoginError as e:
# Authentication failed, let user try again # Authentication failed, let user try again
@ -161,7 +161,9 @@ class AuthRestServlet(RestServlet):
try: try:
await self.auth_handler.add_oob_auth( await self.auth_handler.add_oob_auth(
LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP() LoginType.REGISTRATION_TOKEN,
authdict,
request.getClientAddress().host,
) )
except LoginError as e: except LoginError as e:
html = self.registration_token_template.render( html = self.registration_token_template.render(

View file

@ -176,7 +176,7 @@ class LoginRestServlet(RestServlet):
if appservice.is_rate_limited(): if appservice.is_rate_limited():
await self._address_ratelimiter.ratelimit( await self._address_ratelimiter.ratelimit(
None, request.getClientIP() None, request.getClientAddress().host
) )
result = await self._do_appservice_login( result = await self._do_appservice_login(
@ -188,19 +188,25 @@ class LoginRestServlet(RestServlet):
self.jwt_enabled self.jwt_enabled
and login_submission["type"] == LoginRestServlet.JWT_TYPE and login_submission["type"] == LoginRestServlet.JWT_TYPE
): ):
await self._address_ratelimiter.ratelimit(None, request.getClientIP()) await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
result = await self._do_jwt_login( result = await self._do_jwt_login(
login_submission, login_submission,
should_issue_refresh_token=should_issue_refresh_token, should_issue_refresh_token=should_issue_refresh_token,
) )
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
await self._address_ratelimiter.ratelimit(None, request.getClientIP()) await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
result = await self._do_token_login( result = await self._do_token_login(
login_submission, login_submission,
should_issue_refresh_token=should_issue_refresh_token, should_issue_refresh_token=should_issue_refresh_token,
) )
else: else:
await self._address_ratelimiter.ratelimit(None, request.getClientIP()) await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
result = await self._do_other_login( result = await self._do_other_login(
login_submission, login_submission,
should_issue_refresh_token=should_issue_refresh_token, should_issue_refresh_token=should_issue_refresh_token,

View file

@ -352,7 +352,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
if self.inhibit_user_in_use_error: if self.inhibit_user_in_use_error:
return 200, {"available": True} return 200, {"available": True}
ip = request.getClientIP() ip = request.getClientAddress().host
with self.ratelimiter.ratelimit(ip) as wait_deferred: with self.ratelimiter.ratelimit(ip) as wait_deferred:
await wait_deferred await wait_deferred
@ -394,7 +394,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
) )
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) await self.ratelimiter.ratelimit(None, (request.getClientAddress().host,))
if not self.hs.config.registration.enable_registration: if not self.hs.config.registration.enable_registration:
raise SynapseError( raise SynapseError(
@ -441,7 +441,7 @@ class RegisterRestServlet(RestServlet):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
client_addr = request.getClientIP() client_addr = request.getClientAddress().host
await self.ratelimiter.ratelimit(None, client_addr, update=False) await self.ratelimiter.ratelimit(None, client_addr, update=False)

View file

@ -105,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request)) requester = self.get_success(self.auth.get_user_by_req(request))
@ -124,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10" request.getClientAddress.return_value.host = "192.168.10.10"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request)) requester = self.get_success(self.auth.get_user_by_req(request))
@ -143,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42" request.getClientAddress.return_value.host = "131.111.8.42"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
f = self.get_failure( f = self.get_failure(
@ -190,7 +190,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id] request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@ -209,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id] request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@ -236,7 +236,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_device = simple_async_mock({"hidden": False}) self.store.get_device = simple_async_mock({"hidden": False})
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id] request.args[b"user_id"] = [masquerading_user_id]
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
@ -268,7 +268,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_device = simple_async_mock(None) self.store.get_device = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id] request.args[b"user_id"] = [masquerading_user_id]
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
@ -288,7 +288,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
) )
self.store.insert_client_ip = simple_async_mock(None) self.store.insert_client_ip = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_success(self.auth.get_user_by_req(request)) self.get_success(self.auth.get_user_by_req(request))
@ -305,7 +305,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
) )
self.store.insert_client_ip = simple_async_mock(None) self.store.insert_client_ip = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_success(self.auth.get_user_by_req(request)) self.get_success(self.auth.get_user_by_req(request))

View file

@ -204,7 +204,7 @@ def _mock_request():
mock = Mock( mock = Mock(
spec=[ spec=[
"finish", "finish",
"getClientIP", "getClientAddress",
"getHeader", "getHeader",
"setHeader", "setHeader",
"setResponseCode", "setResponseCode",

View file

@ -1300,7 +1300,7 @@ def _build_callback_request(
"getCookie", "getCookie",
"cookies", "cookies",
"requestHeaders", "requestHeaders",
"getClientIP", "getClientAddress",
"getHeader", "getHeader",
] ]
) )
@ -1310,5 +1310,5 @@ def _build_callback_request(
request.args = {} request.args = {}
request.args[b"code"] = [code.encode("utf-8")] request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")]
request.getClientIP.return_value = ip_address request.getClientAddress.return_value.host = ip_address
return request return request

View file

@ -352,7 +352,7 @@ def _mock_request():
mock = Mock( mock = Mock(
spec=[ spec=[
"finish", "finish",
"getClientIP", "getClientAddress",
"getHeader", "getHeader",
"setHeader", "setHeader",
"setResponseCode", "setResponseCode",

View file

@ -154,10 +154,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(port, 8765) self.assertEqual(port, 8765)
# Set up client side protocol # Set up client side protocol
client_protocol = client_factory.buildProtocol(None) client_address = IPv4Address("TCP", "127.0.0.1", 1234)
client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
# Set up the server side protocol # Set up the server side protocol
channel = self.site.buildProtocol(None) server_address = IPv4Address("TCP", host, port)
channel = self.site.buildProtocol((host, port))
# hook into the channel's request factory so that we can keep a record # hook into the channel's request factory so that we can keep a record
# of the requests # of the requests
@ -173,12 +175,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Connect client to server and vice versa. # Connect client to server and vice versa.
client_to_server_transport = FakeTransport( client_to_server_transport = FakeTransport(
channel, self.reactor, client_protocol channel, self.reactor, client_protocol, server_address, client_address
) )
client_protocol.makeConnection(client_to_server_transport) client_protocol.makeConnection(client_to_server_transport)
server_to_client_transport = FakeTransport( server_to_client_transport = FakeTransport(
client_protocol, self.reactor, channel client_protocol, self.reactor, channel, client_address, server_address
) )
channel.makeConnection(server_to_client_transport) channel.makeConnection(server_to_client_transport)
@ -406,19 +408,21 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(port, repl_port) self.assertEqual(port, repl_port)
# Set up client side protocol # Set up client side protocol
client_protocol = client_factory.buildProtocol(None) client_address = IPv4Address("TCP", "127.0.0.1", 1234)
client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))
# Set up the server side protocol # Set up the server side protocol
channel = self._hs_to_site[hs].buildProtocol(None) server_address = IPv4Address("TCP", host, port)
channel = self._hs_to_site[hs].buildProtocol((host, port))
# Connect client to server and vice versa. # Connect client to server and vice versa.
client_to_server_transport = FakeTransport( client_to_server_transport = FakeTransport(
channel, self.reactor, client_protocol channel, self.reactor, client_protocol, server_address, client_address
) )
client_protocol.makeConnection(client_to_server_transport) client_protocol.makeConnection(client_to_server_transport)
server_to_client_transport = FakeTransport( server_to_client_transport = FakeTransport(
client_protocol, self.reactor, channel client_protocol, self.reactor, channel, client_address, server_address
) )
channel.makeConnection(server_to_client_transport) channel.makeConnection(server_to_client_transport)

View file

@ -181,7 +181,7 @@ class FakeChannel:
self.resource_usage = _self.logcontext.get_resource_usage() self.resource_usage = _self.logcontext.get_resource_usage()
def getPeer(self): def getPeer(self):
# We give an address so that getClientIP returns a non null entry, # We give an address so that getClientAddress/getClientIP returns a non null entry,
# causing us to record the MAU # causing us to record the MAU
return address.IPv4Address("TCP", self._ip, 3423) return address.IPv4Address("TCP", self._ip, 3423)
@ -562,7 +562,10 @@ class FakeTransport:
""" """
_peer_address: Optional[IAddress] = attr.ib(default=None) _peer_address: Optional[IAddress] = attr.ib(default=None)
"""The value to be returend by getPeer""" """The value to be returned by getPeer"""
_host_address: Optional[IAddress] = attr.ib(default=None)
"""The value to be returned by getHost"""
disconnecting = False disconnecting = False
disconnected = False disconnected = False
@ -571,11 +574,11 @@ class FakeTransport:
producer = attr.ib(default=None) producer = attr.ib(default=None)
autoflush = attr.ib(default=True) autoflush = attr.ib(default=True)
def getPeer(self): def getPeer(self) -> Optional[IAddress]:
return self._peer_address return self._peer_address
def getHost(self): def getHost(self) -> Optional[IAddress]:
return None return self._host_address
def loseConnection(self, reason=None): def loseConnection(self, reason=None):
if not self.disconnecting: if not self.disconnecting: