Add RedirectHandler class and factory function for controlling redirects in urllib2

This commit is contained in:
Matt Martz 2016-02-05 12:12:04 -06:00
parent b2c0abe998
commit 0e57c577f4

View file

@ -417,7 +417,7 @@ class RequestWithMethod(urllib2.Request):
def __init__(self, url, method, data=None, headers=None):
if headers is None:
headers = {}
self._method = method
self._method = method.upper()
urllib2.Request.__init__(self, url, data, headers)
def get_method(self):
@ -427,6 +427,55 @@ class RequestWithMethod(urllib2.Request):
return urllib2.Request.get_method(self)
def RedirectHandlerFactory(follow_redirects=None):
"""This is a class factory that closes over the value of
``follow_redirects`` so that the RedirectHandler class has access to
that value without having to use globals, and potentially cause problems
where ``open_url`` or ``fetch_url`` are used multiple times in a module.
"""
class RedirectHandler(urllib2.HTTPRedirectHandler):
"""This is an implementation of a RedirectHandler to match the
functionality provided by httplib2. It will utilize the value of
``follow_redirects`` that is passed into ``RedirectHandlerFactory``
to determine how redirects should be handled in urllib2.
"""
def redirect_request(self, req, fp, code, msg, hdrs, newurl):
if follow_redirects == 'urllib2':
return urllib2.HTTPRedirectHandler.redirect_request(self, req,
fp, code,
msg, hdrs,
newurl)
if follow_redirects in [None, 'no', 'none']:
raise urllib2.HTTPError(newurl, code, msg, hdrs, fp)
do_redirect = False
if follow_redirects in ['all', 'yes']:
do_redirect = (code >= 300 and code < 400)
elif follow_redirects == 'safe':
m = req.get_method()
do_redirect = (code >= 300 and code < 400 and m in ('GET', 'HEAD'))
if do_redirect:
# be conciliant with URIs containing a space
newurl = newurl.replace(' ', '%20')
newheaders = dict((k,v) for k,v in req.headers.items()
if k.lower() not in ("content-length", "content-type")
)
return urllib2.Request(newurl,
headers=newheaders,
origin_req_host=req.get_origin_req_host(),
unverifiable=True)
else:
raise urllib2.HTTPError(req.get_full_url(), code, msg, hdrs,
fp)
return RedirectHandler
class SSLValidationHandler(urllib2.BaseHandler):
'''
A custom handler class for SSL validation.
@ -604,7 +653,8 @@ class SSLValidationHandler(urllib2.BaseHandler):
# Rewrite of fetch_url to not require the module environment
def open_url(url, data=None, headers=None, method=None, use_proxy=True,
force=False, last_mod_time=None, timeout=10, validate_certs=True,
url_username=None, url_password=None, http_agent=None, force_basic_auth=False):
url_username=None, url_password=None, http_agent=None,
force_basic_auth=False, follow_redirects='urllib2'):
'''
Fetches a file from an HTTP/FTP server using urllib2
'''
@ -681,6 +731,9 @@ def open_url(url, data=None, headers=None, method=None, use_proxy=True,
if hasattr(socket, 'create_connection') and CustomHTTPSHandler:
handlers.append(CustomHTTPSHandler)
if follow_redirects != 'urllib2':
handlers.append(RedirectHandlerFactory(follow_redirects))
opener = urllib2.build_opener(*handlers)
urllib2.install_opener(opener)
@ -750,7 +803,8 @@ def url_argument_spec():
)
def fetch_url(module, url, data=None, headers=None, method=None,
use_proxy=True, force=False, last_mod_time=None, timeout=10):
use_proxy=True, force=False, last_mod_time=None, timeout=10,
follow_redirects=False):
'''
Fetches a file from an HTTP/FTP server using urllib2. Requires the module environment
'''
@ -767,14 +821,16 @@ def fetch_url(module, url, data=None, headers=None, method=None,
password = module.params.get('url_password', '')
http_agent = module.params.get('http_agent', None)
force_basic_auth = module.params.get('force_basic_auth', '')
follow_redirects = follow_redirects or module.params.get('follow_redirects', 'urllib2')
r = None
info = dict(url=url)
try:
r = open_url(url, data=data, headers=headers, method=method,
use_proxy=use_proxy, force=force, last_mod_time=last_mod_time, timeout=timeout,
validate_certs=validate_certs, url_username=username,
url_password=password, http_agent=http_agent, force_basic_auth=force_basic_auth)
use_proxy=use_proxy, force=force, last_mod_time=last_mod_time, timeout=timeout,
validate_certs=validate_certs, url_username=username,
url_password=password, http_agent=http_agent, force_basic_auth=force_basic_auth,
follow_redirects=follow_redirects)
info.update(r.info())
info['url'] = r.geturl() # The URL goes in too, because of redirects.
info.update(dict(msg="OK (%s bytes)" % r.headers.get('Content-Length', 'unknown'), status=200))
@ -787,7 +843,7 @@ def fetch_url(module, url, data=None, headers=None, method=None,
except (ConnectionError, ValueError), e:
module.fail_json(msg=str(e))
except urllib2.HTTPError, e:
info.update(dict(msg=str(e), status=e.code))
info.update(dict(msg=str(e), status=e.code, **e.info()))
except urllib2.URLError, e:
code = int(getattr(e, 'code', -1))
info.update(dict(msg="Request failed: %s" % str(e), status=code))