Starting to add additional unit tests for VariableManager

Required some rewiring in inventory code to make sure we're using
the DataLoader class for some data file operations, which makes mocking
them much easier.

Also identified two corner cases not currently handled by the code, related
to inventory variable sources and which one "wins". Also noticed we weren't
properly merging variables from multiple group/host_var file locations
(inventory directory vs. playbook directory locations) so fixed as well.
This commit is contained in:
James Cammarata 2015-09-04 16:41:38 -04:00
parent 87f75a50ad
commit ff9f5d7dc8
13 changed files with 233 additions and 54 deletions

View file

@ -34,7 +34,6 @@ from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleOptionsError
from ansible.utils.unicode import to_bytes
from ansible.utils.display import Display
from ansible.utils.path import is_executable
class SortedOptParser(optparse.OptionParser):
'''Optparser which sorts the options by opt before outputting --help'''
@ -479,7 +478,7 @@ class CLI(object):
return t
@staticmethod
def read_vault_password_file(vault_password_file):
def read_vault_password_file(vault_password_file, loader):
"""
Read a vault password from a file or if executable, execute the script and
retrieve password from STDOUT
@ -489,7 +488,7 @@ class CLI(object):
if not os.path.exists(this_path):
raise AnsibleError("The vault password file %s was not found" % this_path)
if is_executable(this_path):
if loader.is_executable(this_path):
try:
# STDERR not captured to make it easier for users to prompt for input in their scripts
p = subprocess.Popen(this_path, stdout=subprocess.PIPE)

View file

@ -95,13 +95,16 @@ class AdHocCLI(CLI):
(sshpass, becomepass) = self.ask_passwords()
passwords = { 'conn_pass': sshpass, 'become_pass': becomepass }
loader = DataLoader()
if self.options.vault_password_file:
# read vault_pass from a file
vault_pass = CLI.read_vault_password_file(self.options.vault_password_file)
vault_pass = CLI.read_vault_password_file(self.options.vault_password_file, loader=loader)
loader.set_vault_password(vault_pass)
elif self.options.ask_vault_pass:
vault_pass = self.ask_vault_passwords(ask_vault_pass=True, ask_new_vault_pass=False, confirm_new=False)[0]
loader.set_vault_password(vault_pass)
loader = DataLoader(vault_password=vault_pass)
variable_manager = VariableManager()
variable_manager.extra_vars = load_extra_vars(loader=loader, options=self.options)

View file

@ -89,13 +89,15 @@ class PlaybookCLI(CLI):
(sshpass, becomepass) = self.ask_passwords()
passwords = { 'conn_pass': sshpass, 'become_pass': becomepass }
loader = DataLoader()
if self.options.vault_password_file:
# read vault_pass from a file
vault_pass = CLI.read_vault_password_file(self.options.vault_password_file)
vault_pass = CLI.read_vault_password_file(self.options.vault_password_file, loader=loader)
loader.set_vault_password(vault_pass)
elif self.options.ask_vault_pass:
vault_pass = self.ask_vault_passwords(ask_vault_pass=True, ask_new_vault_pass=False, confirm_new=False)[0]
loader = DataLoader(vault_password=vault_pass)
loader.set_vault_password(vault_pass)
# initial error check, to make sure all specified playbooks are accessible
# before we start running anything through the playbook executor

View file

@ -104,9 +104,9 @@ class Inventory(object):
all.add_host(Host(tokens[0], tokens[1]))
else:
all.add_host(Host(x))
elif os.path.exists(host_list):
elif self._loader.path_exists(host_list):
#TODO: switch this to a plugin loader and a 'condition' per plugin on which it should be tried, restoring 'inventory pllugins'
if os.path.isdir(host_list):
if self._loader.is_directory(host_list):
# Ensure basedir is inside the directory
host_list = os.path.join(self.host_list, "")
self.parser = InventoryDirectory(loader=self._loader, filename=host_list)
@ -595,14 +595,14 @@ class Inventory(object):
""" did inventory come from a file? """
if not isinstance(self.host_list, basestring):
return False
return os.path.exists(self.host_list)
return self._loader.path_exists(self.host_list)
def basedir(self):
""" if inventory came from a file, what's the directory? """
dname = self.host_list
if not self.is_file():
dname = None
elif os.path.isdir(self.host_list):
elif self._loader.is_directory(self.host_list):
dname = self.host_list
else:
dname = os.path.dirname(self.host_list)

View file

@ -29,7 +29,6 @@ from ansible.inventory.host import Host
from ansible.inventory.group import Group
from ansible.utils.vars import combine_vars
from ansible.utils.path import is_executable
from ansible.inventory.ini import InventoryParser as InventoryINIParser
from ansible.inventory.script import InventoryScript
@ -54,7 +53,7 @@ def get_file_parser(hostsfile, loader):
except:
pass
if is_executable(hostsfile):
if loader.is_executable(hostsfile):
try:
parser = InventoryScript(loader=loader, filename=hostsfile)
processed = True
@ -65,10 +64,10 @@ def get_file_parser(hostsfile, loader):
if not processed:
try:
parser = InventoryINIParser(filename=hostsfile)
parser = InventoryINIParser(loader=loader, filename=hostsfile)
processed = True
except Exception as e:
if shebang_present and not is_executable(hostsfile):
if shebang_present and not loader.is_executable(hostsfile):
myerr.append("The file %s looks like it should be an executable inventory script, but is not marked executable. " % hostsfile + \
"Perhaps you want to correct this with `chmod +x %s`?" % hostsfile)
else:

View file

@ -114,12 +114,15 @@ class Host:
def get_vars(self):
results = {}
groups = self.get_groups()
for group in sorted(groups, key=lambda g: g.depth):
results = combine_vars(results, group.get_vars())
results = combine_vars(results, self.vars)
results['inventory_hostname'] = self.name
results['inventory_hostname_short'] = self.name.split('.')[0]
results['group_names'] = sorted([ g.name for g in groups if g.name != 'all'])
results['group_names'] = sorted([ g.name for g in self.get_groups() if g.name != 'all'])
return results
def get_group_vars(self):
results = {}
groups = self.get_groups()
for group in sorted(groups, key=lambda g: g.depth):
results = combine_vars(results, group.get_vars())
return results

View file

@ -37,7 +37,8 @@ class InventoryParser(object):
with their associated hosts and variable settings.
"""
def __init__(self, filename=C.DEFAULT_HOST_LIST):
def __init__(self, loader, filename=C.DEFAULT_HOST_LIST):
self._loader = loader
self.filename = filename
# Start with an empty host list and the default 'all' and
@ -53,8 +54,14 @@ class InventoryParser(object):
# Read in the hosts, groups, and variables defined in the
# inventory file.
with open(filename) as fh:
self._parse(fh.readlines())
if loader:
(data, private) = loader._get_file_contents(filename)
data = data.split('\n')
else:
with open(filename) as fh:
data = fh.readlines()
self._parse(data)
# Finally, add all top-level groups (including 'ungrouped') as
# children of 'all'.

View file

@ -22,6 +22,7 @@ __metaclass__ = type
import copy
import json
import os
import stat
from yaml import load, YAMLError
from six import text_type
@ -56,11 +57,15 @@ class DataLoader():
ds = dl.load_from_file('/path/to/file')
'''
def __init__(self, vault_password=None):
def __init__(self):
self._basedir = '.'
self._vault_password = vault_password
self._FILE_CACHE = dict()
# initialize the vault stuff with an empty password
self.set_vault_password(None)
def set_vault_password(self, vault_password):
self._vault_password = vault_password
self._vault = VaultLib(password=vault_password)
def load(self, data, file_name='<string>', show_content=True):
@ -130,6 +135,11 @@ class DataLoader():
path = self.path_dwim(path)
return os.listdir(path)
def is_executable(self, path):
'''is the given path executable?'''
path = self.path_dwim(path)
return (stat.S_IXUSR & os.stat(path)[stat.ST_MODE] or stat.S_IXGRP & os.stat(path)[stat.ST_MODE] or stat.S_IXOTH & os.stat(path)[stat.ST_MODE])
def _safe_load(self, stream, file_name=None):
''' Implements yaml.safe_load(), except using our custom loader class. '''
@ -249,3 +259,29 @@ class DataLoader():
return candidate
def read_vault_password_file(self, vault_password_file):
"""
Read a vault password from a file or if executable, execute the script and
retrieve password from STDOUT
"""
this_path = os.path.realpath(os.path.expanduser(vault_password_file))
if not os.path.exists(this_path):
raise AnsibleError("The vault password file %s was not found" % this_path)
if self.is_executable(this_path):
try:
# STDERR not captured to make it easier for users to prompt for input in their scripts
p = subprocess.Popen(this_path, stdout=subprocess.PIPE)
except OSError as e:
raise AnsibleError("Problem running vault password script %s (%s). If this is not a script, remove the executable bit from the file." % (' '.join(this_path), e))
stdout, stderr = p.communicate()
self.set_vault_password(stdout.strip('\r\n'))
else:
try:
f = open(this_path, "rb")
self.set_vault_password(f.read().strip())
f.close()
except (OSError, IOError) as e:
raise AnsibleError("Could not read vault password file %s: %s" % (this_path, e))

View file

@ -22,11 +22,7 @@ import stat
from time import sleep
from errno import EEXIST
__all__ = ['is_executable', 'unfrackpath']
def is_executable(path):
'''is the given path executable?'''
return (stat.S_IXUSR & os.stat(path)[stat.ST_MODE] or stat.S_IXGRP & os.stat(path)[stat.ST_MODE] or stat.S_IXOTH & os.stat(path)[stat.ST_MODE])
__all__ = ['unfrackpath']
def unfrackpath(path):
'''

View file

@ -119,11 +119,11 @@ class VariableManager:
- host_vars_files[host] (if there is a host context)
- host->get_vars (if there is a host context)
- fact_cache[host] (if there is a host context)
- vars_cache[host] (if there is a host context)
- play vars (if there is a play context)
- play vars_files (if there's no host context, ignore
file names that cannot be templated)
- task->get_vars (if there is a task context)
- vars_cache[host] (if there is a host context)
- extra vars
'''
@ -152,29 +152,34 @@ class VariableManager:
# files and then any vars from host_vars files which may apply to
# this host or the groups it belongs to
# we merge in the special 'all' group_vars first, if they exist
# we merge in vars from groups specified in the inventory (INI or script)
all_vars = combine_vars(all_vars, host.get_group_vars())
# then we merge in the special 'all' group_vars first, if they exist
if 'all' in self._group_vars_files:
data = self._preprocess_vars(self._group_vars_files['all'])
for item in data:
all_vars = combine_vars(all_vars, item)
for group in host.get_groups():
all_vars = combine_vars(all_vars, group.get_vars())
if group.name in self._group_vars_files and group.name != 'all':
data = self._preprocess_vars(self._group_vars_files[group.name])
for data in self._group_vars_files[group.name]:
data = self._preprocess_vars(data)
for item in data:
all_vars = combine_vars(all_vars, item)
# then we merge in vars from the host specified in the inventory (INI or script)
all_vars = combine_vars(all_vars, host.get_vars())
# then we merge in the host_vars/<hostname> file, if it exists
host_name = host.get_name()
if host_name in self._host_vars_files:
for data in self._host_vars_files[host_name]:
data = self._preprocess_vars(data)
for item in data:
all_vars = combine_vars(all_vars, item)
host_name = host.get_name()
if host_name in self._host_vars_files:
data = self._preprocess_vars(self._host_vars_files[host_name])
for item in data:
all_vars = combine_vars(all_vars, self._host_vars_files[host_name])
# then we merge in vars specified for this host
all_vars = combine_vars(all_vars, host.get_vars())
# next comes the facts cache and the vars cache, respectively
# finally, the facts cache for this host, if it exists
try:
host_facts = self._fact_cache.get(host.name, dict())
for k in host_facts.keys():
@ -333,7 +338,9 @@ class VariableManager:
(name, data) = self._load_inventory_file(path, loader)
if data:
self._host_vars_files[name] = data
if name not in self._host_vars_files:
self._host_vars_files[name] = []
self._host_vars_files[name].append(data)
return data
else:
return dict()
@ -347,7 +354,9 @@ class VariableManager:
(name, data) = self._load_inventory_file(path, loader)
if data:
self._group_vars_files[name] = data
if name not in self._group_vars_files:
self._group_vars_files[name] = []
self._group_vars_files[name].append(data)
return data
else:
return dict()

View file

@ -57,6 +57,10 @@ class DictDataLoader(DataLoader):
def list_directory(self, path):
return [x for x in self._known_directories]
def is_executable(self, path):
# FIXME: figure out a way to make paths return true for this
return False
def _add_known_directory(self, directory):
if directory not in self._known_directories:
self._known_directories.append(directory)

View file

@ -66,7 +66,8 @@ class TestDataLoader(unittest.TestCase):
class TestDataLoaderWithVault(unittest.TestCase):
def setUp(self):
self._loader = DataLoader(vault_password='ansible')
self._loader = DataLoader()
self._loader.set_vault_password('ansible')
def tearDown(self):
pass

View file

@ -19,11 +19,13 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from collections import defaultdict
from six import iteritems
from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch, MagicMock
from ansible.inventory import Inventory
from ansible.playbook.play import Play
from ansible.vars import VariableManager
from units.mock.loader import DictDataLoader
@ -68,20 +70,27 @@ class TestVariableManager(unittest.TestCase):
fake_loader = DictDataLoader({
"host_vars/hostname1.yml": """
foo: bar
"""
""",
"other_path/host_vars/hostname1.yml": """
foo: bam
baa: bat
""",
})
v = VariableManager()
v.add_host_vars_file("host_vars/hostname1.yml", loader=fake_loader)
v.add_host_vars_file("other_path/host_vars/hostname1.yml", loader=fake_loader)
self.assertIn("hostname1", v._host_vars_files)
self.assertEqual(v._host_vars_files["hostname1"], dict(foo="bar"))
self.assertEqual(v._host_vars_files["hostname1"], [dict(foo="bar"), dict(foo="bam", baa="bat")])
mock_host = MagicMock()
mock_host.get_name.return_value = "hostname1"
mock_host.get_vars.return_value = dict()
mock_host.get_groups.return_value = ()
mock_host.get_group_vars.return_value = dict()
self.assertEqual(v.get_vars(loader=fake_loader, host=mock_host, use_cache=False).get("foo"), "bar")
self.assertEqual(v.get_vars(loader=fake_loader, host=mock_host, use_cache=False).get("foo"), "bam")
self.assertEqual(v.get_vars(loader=fake_loader, host=mock_host, use_cache=False).get("baa"), "bat")
def test_variable_manager_group_vars_file(self):
fake_loader = DictDataLoader({
@ -90,15 +99,19 @@ class TestVariableManager(unittest.TestCase):
""",
"group_vars/somegroup.yml": """
bam: baz
""",
"other_path/group_vars/somegroup.yml": """
baa: bat
"""
})
v = VariableManager()
v.add_group_vars_file("group_vars/all.yml", loader=fake_loader)
v.add_group_vars_file("group_vars/somegroup.yml", loader=fake_loader)
v.add_group_vars_file("other_path/group_vars/somegroup.yml", loader=fake_loader)
self.assertIn("somegroup", v._group_vars_files)
self.assertEqual(v._group_vars_files["all"], dict(foo="bar"))
self.assertEqual(v._group_vars_files["somegroup"], dict(bam="baz"))
self.assertEqual(v._group_vars_files["all"], [dict(foo="bar")])
self.assertEqual(v._group_vars_files["somegroup"], [dict(bam="baz"), dict(baa="bat")])
mock_group = MagicMock()
mock_group.name = "somegroup"
@ -109,10 +122,11 @@ class TestVariableManager(unittest.TestCase):
mock_host.get_name.return_value = "hostname1"
mock_host.get_vars.return_value = dict()
mock_host.get_groups.return_value = (mock_group,)
mock_host.get_group_vars.return_value = dict()
vars = v.get_vars(loader=fake_loader, host=mock_host, use_cache=False)
self.assertEqual(vars.get("foo"), "bar")
self.assertEqual(vars.get("bam"), "baz")
self.assertEqual(vars.get("baa"), "bat")
def test_variable_manager_play_vars(self):
fake_loader = DictDataLoader({})
@ -150,3 +164,109 @@ class TestVariableManager(unittest.TestCase):
v = VariableManager()
self.assertEqual(v.get_vars(loader=fake_loader, task=mock_task, use_cache=False).get("foo"), "bar")
def test_variable_manager_precedence(self):
'''
Tests complex variations and combinations of get_vars() with different
objects to modify the context under which variables are merged.
'''
v = VariableManager()
v._fact_cache = defaultdict(dict)
fake_loader = DictDataLoader({
# inventory1
'/etc/ansible/inventory1': """
[group2:children]
group1
[group1]
host1 host_var=host_var_from_inventory_host1
[group1:vars]
group_var = group_var_from_inventory_group1
[group2:vars]
group_var = group_var_from_inventory_group2
""",
# role defaults_only1
'/etc/ansible/roles/defaults_only1/defaults/main.yml': """
default_var: "default_var_from_defaults_only1"
host_var: "host_var_from_defaults_only1"
group_var: "group_var_from_defaults_only1"
group_var_all: "group_var_all_from_defaults_only1"
extra_var: "extra_var_from_defaults_only1"
""",
'/etc/ansible/roles/defaults_only1/tasks/main.yml': """
- debug: msg="here i am"
""",
# role defaults_only2
'/etc/ansible/roles/defaults_only2/defaults/main.yml': """
default_var: "default_var_from_defaults_only2"
host_var: "host_var_from_defaults_only2"
group_var: "group_var_from_defaults_only2"
group_var_all: "group_var_all_from_defaults_only2"
extra_var: "extra_var_from_defaults_only2"
""",
})
inv1 = Inventory(loader=fake_loader, variable_manager=v, host_list='/etc/ansible/inventory1')
inv1.set_playbook_basedir('./')
play1 = Play.load(dict(
hosts=['all'],
roles=['defaults_only1', 'defaults_only2'],
), loader=fake_loader, variable_manager=v)
# first we assert that the defaults as viewed as a whole are the merged results
# of the defaults from each role, with the last role defined "winning" when
# there is a variable naming conflict
res = v.get_vars(loader=fake_loader, play=play1)
self.assertEqual(res['default_var'], 'default_var_from_defaults_only2')
# next, we assert that when vars are viewed from the context of a task within a
# role, that task will see its own role defaults before any other role's
blocks = play1.compile()
task = blocks[1].block[0]
res = v.get_vars(loader=fake_loader, play=play1, task=task)
self.assertEqual(res['default_var'], 'default_var_from_defaults_only1')
# next we assert the precendence of inventory variables
v.set_inventory(inv1)
h1 = inv1.get_host('host1')
res = v.get_vars(loader=fake_loader, play=play1, host=h1)
self.assertEqual(res['group_var'], 'group_var_from_inventory_group1')
self.assertEqual(res['host_var'], 'host_var_from_inventory_host1')
# next we test with group_vars/ files loaded
fake_loader.push("/etc/ansible/group_vars/all", """
group_var_all: group_var_all_from_group_vars_all
""")
fake_loader.push("/etc/ansible/group_vars/group1", """
group_var: group_var_from_group_vars_group1
""")
fake_loader.push("/etc/ansible/group_vars/group3", """
# this is a dummy, which should not be used anywhere
group_var: group_var_from_group_vars_group3
""")
fake_loader.push("/etc/ansible/host_vars/host1", """
host_var: host_var_from_host_vars_host1
""")
v.add_group_vars_file("/etc/ansible/group_vars/all", loader=fake_loader)
v.add_group_vars_file("/etc/ansible/group_vars/group1", loader=fake_loader)
v.add_group_vars_file("/etc/ansible/group_vars/group2", loader=fake_loader)
v.add_host_vars_file("/etc/ansible/host_vars/host1", loader=fake_loader)
res = v.get_vars(loader=fake_loader, play=play1, host=h1)
self.assertEqual(res['group_var'], 'group_var_from_group_vars_group1')
self.assertEqual(res['group_var_all'], 'group_var_all_from_group_vars_all')
self.assertEqual(res['host_var'], 'host_var_from_host_vars_host1')
# add in the fact cache
v._fact_cache['host1'] = dict(fact_cache_var="fact_cache_var_from_fact_cache")
res = v.get_vars(loader=fake_loader, play=play1, host=h1)
self.assertEqual(res['fact_cache_var'], 'fact_cache_var_from_fact_cache')