From 102ee6a3b47f2c7aebaa90356e96bd282236df23 Mon Sep 17 00:00:00 2001
From: Ben Doherty <ben@thinkoomph.com>
Date: Tue, 31 May 2016 18:31:07 -0400
Subject: [PATCH] Some refactoring:

* rename archive -> arcfile (where it's a file descriptor)
* additional return
* simplify logic around 'archive?' flag
* maintain os separator after arcroot
* use function instead of lambda for filter, ensure file exists before file.cmp'ing it
* track errored files and fail if there are any
---
 lib/ansible/modules/extras/files/archive.py | 40 +++++++++++++++------
 1 file changed, 29 insertions(+), 11 deletions(-)

diff --git a/lib/ansible/modules/extras/files/archive.py b/lib/ansible/modules/extras/files/archive.py
index 53df2d109d4..3ef0dcb20cc 100644
--- a/lib/ansible/modules/extras/files/archive.py
+++ b/lib/ansible/modules/extras/files/archive.py
@@ -148,11 +148,11 @@ def main():
             expanded_paths.append(path)
 
     if len(expanded_paths) == 0:
-        module.fail_json(path=', '.join(paths), expanded_paths=', '.join(expanded_paths), msg='Error, no source paths were found')
+        return module.fail_json(path=', '.join(paths), expanded_paths=', '.join(expanded_paths), msg='Error, no source paths were found')
 
     # If we actually matched multiple files or TRIED to, then
     # treat this as a multi-file archive
-    archive = globby or len(expanded_paths) > 1 or any(os.path.isdir(path) for path in expanded_paths)
+    archive = globby or os.path.isdir(expanded_paths[0]) or len(expanded_paths) > 1
 
     # Default created file name (for single-file archives) to
     # <file>.<compression>
@@ -181,6 +181,8 @@ def main():
             if i < len(arcroot):
                 arcroot = os.path.dirname(arcroot[0:i+1])
 
+            arcroot += os.sep
+
         # Don't allow archives to be created anywhere within paths to be removed
         if remove and os.path.isdir(path) and creates.startswith(path):
             module.fail_json(path=', '.join(paths), msg='Error, created archive can not be contained in source paths when remove=True')
@@ -219,7 +221,9 @@ def main():
             try:
                 # Easier compression using tarfile module
                 if compression == 'gz' or compression == 'bz2':
-                    archive = tarfile.open(creates, 'w|' + compression)
+                    arcfile = tarfile.open(creates, 'w|' + compression)
+
+                    arcfile.add(arcroot, os.path.basename(arcroot), recursive=False)
 
                     for path in archive_paths:
                         basename = ''
@@ -228,12 +232,23 @@ def main():
                         if os.path.isdir(path) and not path.endswith(os.sep + '.'):
                             basename = os.path.basename(path) + os.sep
 
-                        archive.add(path, path[len(arcroot):], filter=lambda f: not filecmp.cmp(f.name, creates) and f)
-                        successes.append(path)
+                        try:
+                            def exclude_creates(f):
+                                if os.path.exists(f.name) and not filecmp.cmp(f.name, creates):
+                                    return f
+
+                                return None
+
+                            arcfile.add(path, basename + path[len(arcroot):], filter=exclude_creates)
+                            successes.append(path)
+
+                        except:
+                            e = get_exception()
+                            errors.append('error adding %s: %s' % (path, str(e)))
 
                 # Slightly more difficult (and less efficient!) compression using zipfile module
                 elif compression == 'zip':
-                    archive = zipfile.ZipFile(creates, 'w', zipfile.ZIP_DEFLATED)
+                    arcfile = zipfile.ZipFile(creates, 'w', zipfile.ZIP_DEFLATED)
 
                     for path in archive_paths:
                         basename = ''
@@ -244,23 +259,26 @@ def main():
 
                         for dirpath, dirnames, filenames in os.walk(path, topdown=True):
                             for dirname in dirnames:
-                                archive.write(dirpath + os.sep + dirname, basename + dirname)
+                                arcfile.write(dirpath + os.sep + dirname, basename + dirname)
                             for filename in filenames:
                                 fullpath = dirpath + os.sep + filename
 
                                 if not filecmp.cmp(fullpath, creates):
-                                    archive.write(fullpath, basename + filename)
+                                    arcfile.write(fullpath, basename + filename)
 
                         successes.append(path)
 
             except OSError:
                 e = get_exception()
-                module.fail_json(msg='Error when writing zip archive at %s: %s' % (creates, str(e)))
+                module.fail_json(msg='Error when writing %s archive at %s: %s' % (compression == 'zip' and 'zip' or ('tar.' + compression), creates, str(e)))
 
-            if archive:
-                archive.close()
+            if arcfile:
+                arcfile.close()
                 state = 'archive'
 
+            if len(errors) > 0:
+                module.fail_json(msg='Errors when writing archive at %s: %s' % (creates, '; '.join(errors)))
+
         if state in ['archive', 'incomplete'] and remove:
             for path in successes:
                 try: