Fix @contextmanager leak on exception. (#21031)
* Fix @contextmanager leak on exception.
* Fix test leaks of global module args cache.
(cherry picked from commit 272ff10fa1
)
This commit is contained in:
parent
6176c95838
commit
cb93ecaef9
4 changed files with 25 additions and 13 deletions
|
@ -36,18 +36,22 @@ def swap_stdin_and_argv(stdin_data='', argv_data=tuple()):
|
|||
context manager that temporarily masks the test runner's values for stdin and argv
|
||||
"""
|
||||
real_stdin = sys.stdin
|
||||
real_argv = sys.argv
|
||||
|
||||
if PY3:
|
||||
sys.stdin = StringIO(stdin_data)
|
||||
sys.stdin.buffer = BytesIO(to_bytes(stdin_data))
|
||||
fake_stream = StringIO(stdin_data)
|
||||
fake_stream.buffer = BytesIO(to_bytes(stdin_data))
|
||||
else:
|
||||
sys.stdin = BytesIO(to_bytes(stdin_data))
|
||||
fake_stream = BytesIO(to_bytes(stdin_data))
|
||||
|
||||
real_argv = sys.argv
|
||||
sys.argv = argv_data
|
||||
yield
|
||||
sys.stdin = real_stdin
|
||||
sys.argv = real_argv
|
||||
try:
|
||||
sys.stdin = fake_stream
|
||||
sys.argv = argv_data
|
||||
|
||||
yield
|
||||
finally:
|
||||
sys.stdin = real_stdin
|
||||
sys.argv = real_argv
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
@ -56,13 +60,18 @@ def swap_stdout():
|
|||
context manager that temporarily replaces stdout for tests that need to verify output
|
||||
"""
|
||||
old_stdout = sys.stdout
|
||||
|
||||
if PY3:
|
||||
fake_stream = StringIO()
|
||||
else:
|
||||
fake_stream = BytesIO()
|
||||
sys.stdout = fake_stream
|
||||
yield fake_stream
|
||||
sys.stdout = old_stdout
|
||||
|
||||
try:
|
||||
sys.stdout = fake_stream
|
||||
|
||||
yield fake_stream
|
||||
finally:
|
||||
sys.stdout = old_stdout
|
||||
|
||||
|
||||
class ModuleTestCase(unittest.TestCase):
|
||||
|
|
|
@ -40,6 +40,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
|
|||
from ansible.module_utils import basic
|
||||
|
||||
# test basic log invocation
|
||||
basic._ANSIBLE_ARGS = None
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec=dict(
|
||||
foo = dict(default=True, type='bool'),
|
||||
|
|
|
@ -315,6 +315,7 @@ class TestModuleUtilsBasic(ModuleTestCase):
|
|||
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo":"hello", "bar": "bad", "bam": "bad"}))
|
||||
|
||||
with swap_stdin_and_argv(stdin_data=args):
|
||||
basic._ANSIBLE_ARGS = None
|
||||
self.assertRaises(
|
||||
SystemExit,
|
||||
basic.AnsibleModule,
|
||||
|
@ -331,6 +332,7 @@ class TestModuleUtilsBasic(ModuleTestCase):
|
|||
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"bam": "bad"}))
|
||||
|
||||
with swap_stdin_and_argv(stdin_data=args):
|
||||
basic._ANSIBLE_ARGS = None
|
||||
self.assertRaises(
|
||||
SystemExit,
|
||||
basic.AnsibleModule,
|
||||
|
@ -583,12 +585,11 @@ class TestModuleUtilsBasic(ModuleTestCase):
|
|||
|
||||
def test_module_utils_basic_ansible_module_is_special_selinux_path(self):
|
||||
from ansible.module_utils import basic
|
||||
basic._ANSIBLE_ARGS = None
|
||||
|
||||
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'_ansible_selinux_special_fs': "nfs,nfsd,foos"}))
|
||||
|
||||
with swap_stdin_and_argv(stdin_data=args):
|
||||
|
||||
basic._ANSIBLE_ARGS = None
|
||||
am = basic.AnsibleModule(
|
||||
argument_spec = dict(),
|
||||
)
|
||||
|
|
|
@ -703,6 +703,7 @@ def test_distribution_version():
|
|||
|
||||
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}))
|
||||
with swap_stdin_and_argv(stdin_data=args):
|
||||
basic._ANSIBLE_ARGS = None
|
||||
module = basic.AnsibleModule(argument_spec=dict())
|
||||
|
||||
for t in TESTSETS:
|
||||
|
|
Loading…
Reference in a new issue