diff --git a/lib/ansible/galaxy/collection.py b/lib/ansible/galaxy/collection.py index 72e4d9a4238..be037911791 100644 --- a/lib/ansible/galaxy/collection.py +++ b/lib/ansible/galaxy/collection.py @@ -1384,7 +1384,23 @@ def _download_file(url, b_path, expected_hash, validate_certs, headers=None): def _extract_tar_dir(tar, dirname, b_dest): """ Extracts a directory from a collection tar. """ - tar_member = tar.getmember(to_native(dirname, errors='surrogate_or_strict')) + member_names = [to_native(dirname, errors='surrogate_or_strict')] + + # Create list of members with and without trailing separator + if not member_names[-1].endswith(os.path.sep): + member_names.append(member_names[-1] + os.path.sep) + + # Try all of the member names and stop on the first one that are able to successfully get + for member in member_names: + try: + tar_member = tar.getmember(member) + except KeyError: + continue + break + else: + # If we still can't find the member, raise a nice error. + raise AnsibleError("Unable to extract '%s' from collection" % to_native(member, errors='surrogate_or_strict')) + b_dir_path = os.path.join(b_dest, to_bytes(dirname, errors='surrogate_or_strict')) b_parent_path = os.path.dirname(b_dir_path) @@ -1403,7 +1419,8 @@ def _extract_tar_dir(tar, dirname, b_dest): os.symlink(b_link_path, b_dir_path) else: - os.mkdir(b_dir_path, 0o0755) + if not os.path.isdir(b_dir_path): + os.mkdir(b_dir_path, 0o0755) def _extract_tar_file(tar, filename, b_dest, b_temp_path, expected_hash=None): diff --git a/test/units/cli/galaxy/test_collection_extract_tar.py b/test/units/cli/galaxy/test_collection_extract_tar.py new file mode 100644 index 00000000000..526442cc9de --- /dev/null +++ b/test/units/cli/galaxy/test_collection_extract_tar.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2020 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import pytest + +from ansible.errors import AnsibleError +from ansible.galaxy.collection import _extract_tar_dir + + +@pytest.fixture +def fake_tar_obj(mocker): + m_tarfile = mocker.Mock() + m_tarfile.type = mocker.Mock(return_value=b'99') + m_tarfile.SYMTYPE = mocker.Mock(return_value=b'22') + + return m_tarfile + + +def test_extract_tar_member_trailing_sep(mocker): + m_tarfile = mocker.Mock() + m_tarfile.getmember = mocker.Mock(side_effect=KeyError) + + with pytest.raises(AnsibleError, match='Unable to extract'): + _extract_tar_dir(m_tarfile, '/some/dir/', b'/some/dest') + + assert m_tarfile.getmember.call_count == 1 + + +def test_extract_tar_member_no_trailing_sep(mocker): + m_tarfile = mocker.Mock() + m_tarfile.getmember = mocker.Mock(side_effect=KeyError) + + with pytest.raises(AnsibleError, match='Unable to extract'): + _extract_tar_dir(m_tarfile, '/some/dir', b'/some/dest') + + assert m_tarfile.getmember.call_count == 2 + + +def test_extract_tar_dir_exists(mocker, fake_tar_obj): + mocker.patch('os.makedirs', return_value=None) + m_makedir = mocker.patch('os.mkdir', return_value=None) + mocker.patch('os.path.isdir', return_value=True) + + _extract_tar_dir(fake_tar_obj, '/some/dir', b'/some/dest') + + assert not m_makedir.called + + +def test_extract_tar_dir_does_not_exist(mocker, fake_tar_obj): + mocker.patch('os.makedirs', return_value=None) + m_makedir = mocker.patch('os.mkdir', return_value=None) + mocker.patch('os.path.isdir', return_value=False) + + _extract_tar_dir(fake_tar_obj, '/some/dir', b'/some/dest') + + assert m_makedir.called + assert m_makedir.call_args[0] == (b'/some/dir', 0o0755)