* PluginLoader class will now be more selective about loading some
  plugin classes, if a required base class is specified (used to avoid
  loading v1 plugins that have changed significantly in their apis)
* Added ability for the connection info class to read values from a
  given hosts variables, to support "magic" variables
* Added some more magic variables to the VariableManager output
* Fixed a bug in the ActionBase class, where the module configuration
  code was not correctly handling unicode
This commit is contained in:
James Cammarata 2015-05-11 11:22:41 -05:00
parent f141ec9671
commit daf533c80e
5 changed files with 75 additions and 36 deletions

View file

@ -29,6 +29,20 @@ from ansible.errors import AnsibleError
__all__ = ['ConnectionInformation'] __all__ = ['ConnectionInformation']
# the magic variable mapping dictionary below is used to translate
# host/inventory variables to fields in the ConnectionInformation
# object. The dictionary values are tuples, to account for aliases
# in variable names.
MAGIC_VARIABLE_MAPPING = dict(
connection = ('ansible_connection',),
remote_addr = ('ansible_ssh_host', 'ansible_host'),
remote_user = ('ansible_ssh_user', 'ansible_user'),
port = ('ansible_ssh_port', 'ansible_port'),
password = ('ansible_ssh_pass', 'ansible_password'),
private_key_file = ('ansible_ssh_private_key_file', 'ansible_private_key_file'),
shell = ('ansible_shell_type',),
)
class ConnectionInformation: class ConnectionInformation:
@ -51,6 +65,7 @@ class ConnectionInformation:
self.port = None self.port = None
self.private_key_file = C.DEFAULT_PRIVATE_KEY_FILE self.private_key_file = C.DEFAULT_PRIVATE_KEY_FILE
self.timeout = C.DEFAULT_TIMEOUT self.timeout = C.DEFAULT_TIMEOUT
self.shell = None
# privilege escalation # privilege escalation
self.become = None self.become = None
@ -170,7 +185,7 @@ class ConnectionInformation:
else: else:
setattr(self, field, value) setattr(self, field, value)
def set_task_override(self, task): def set_task_and_host_override(self, task, host):
''' '''
Sets attributes from the task if they are set, which will override Sets attributes from the task if they are set, which will override
those from the play. those from the play.
@ -179,12 +194,22 @@ class ConnectionInformation:
new_info = ConnectionInformation() new_info = ConnectionInformation()
new_info.copy(self) new_info.copy(self)
# loop through a subset of attributes on the task object and set
# connection fields based on their values
for attr in ('connection', 'remote_user', 'become', 'become_user', 'become_pass', 'become_method', 'environment', 'no_log'): for attr in ('connection', 'remote_user', 'become', 'become_user', 'become_pass', 'become_method', 'environment', 'no_log'):
if hasattr(task, attr): if hasattr(task, attr):
attr_val = getattr(task, attr) attr_val = getattr(task, attr)
if attr_val: if attr_val:
setattr(new_info, attr, attr_val) setattr(new_info, attr, attr_val)
# finally, use the MAGIC_VARIABLE_MAPPING dictionary to update this
# connection info object with 'magic' variables from inventory
variables = host.get_vars()
for (attr, variable_names) in MAGIC_VARIABLE_MAPPING.iteritems():
for variable_name in variable_names:
if variable_name in variables:
setattr(new_info, attr, variables[variable_name])
return new_info return new_info
def make_become_cmd(self, cmd, executable, become_settings=None): def make_become_cmd(self, cmd, executable, become_settings=None):

View file

@ -111,7 +111,7 @@ class WorkerProcess(multiprocessing.Process):
# apply the given task's information to the connection info, # apply the given task's information to the connection info,
# which may override some fields already set by the play or # which may override some fields already set by the play or
# the options specified on the command line # the options specified on the command line
new_connection_info = connection_info.set_task_override(task) new_connection_info = connection_info.set_task_and_host_override(task=task, host=host)
# execute the task and build a TaskResult from the result # execute the task and build a TaskResult from the result
debug("running TaskExecutor() for %s/%s" % (host, task)) debug("running TaskExecutor() for %s/%s" % (host, task))

View file

@ -55,9 +55,10 @@ class PluginLoader:
The first match is used. The first match is used.
''' '''
def __init__(self, class_name, package, config, subdir, aliases={}): def __init__(self, class_name, package, config, subdir, aliases={}, required_base_class=None):
self.class_name = class_name self.class_name = class_name
self.base_class = required_base_class
self.package = package self.package = package
self.config = config self.config = config
self.subdir = subdir self.subdir = subdir
@ -87,11 +88,12 @@ class PluginLoader:
config = data.get('config') config = data.get('config')
subdir = data.get('subdir') subdir = data.get('subdir')
aliases = data.get('aliases') aliases = data.get('aliases')
base_class = data.get('base_class')
PATH_CACHE[class_name] = data.get('PATH_CACHE') PATH_CACHE[class_name] = data.get('PATH_CACHE')
PLUGIN_PATH_CACHE[class_name] = data.get('PLUGIN_PATH_CACHE') PLUGIN_PATH_CACHE[class_name] = data.get('PLUGIN_PATH_CACHE')
self.__init__(class_name, package, config, subdir, aliases) self.__init__(class_name, package, config, subdir, aliases, base_class)
self._extra_dirs = data.get('_extra_dirs', []) self._extra_dirs = data.get('_extra_dirs', [])
self._searched_paths = data.get('_searched_paths', set()) self._searched_paths = data.get('_searched_paths', set())
@ -102,6 +104,7 @@ class PluginLoader:
return dict( return dict(
class_name = self.class_name, class_name = self.class_name,
base_class = self.base_class,
package = self.package, package = self.package,
config = self.config, config = self.config,
subdir = self.subdir, subdir = self.subdir,
@ -268,9 +271,13 @@ class PluginLoader:
self._module_cache[path] = imp.load_source('.'.join([self.package, name]), path) self._module_cache[path] = imp.load_source('.'.join([self.package, name]), path)
if kwargs.get('class_only', False): if kwargs.get('class_only', False):
return getattr(self._module_cache[path], self.class_name) obj = getattr(self._module_cache[path], self.class_name)
else: else:
return getattr(self._module_cache[path], self.class_name)(*args, **kwargs) obj = getattr(self._module_cache[path], self.class_name)(*args, **kwargs)
if self.base_class and self.base_class not in [base.__name__ for base in obj.__class__.__bases__]:
return None
return obj
def all(self, *args, **kwargs): def all(self, *args, **kwargs):
''' instantiates all plugins with the same arguments ''' ''' instantiates all plugins with the same arguments '''
@ -291,6 +298,9 @@ class PluginLoader:
else: else:
obj = getattr(self._module_cache[path], self.class_name)(*args, **kwargs) obj = getattr(self._module_cache[path], self.class_name)(*args, **kwargs)
if self.base_class and self.base_class not in [base.__name__ for base in obj.__class__.__bases__]:
continue
# set extra info on the module, in case we want it later # set extra info on the module, in case we want it later
setattr(obj, '_original_path', path) setattr(obj, '_original_path', path)
yield obj yield obj
@ -299,21 +309,22 @@ action_loader = PluginLoader(
'ActionModule', 'ActionModule',
'ansible.plugins.action', 'ansible.plugins.action',
C.DEFAULT_ACTION_PLUGIN_PATH, C.DEFAULT_ACTION_PLUGIN_PATH,
'action_plugins' 'action_plugins',
required_base_class='ActionBase',
) )
cache_loader = PluginLoader( cache_loader = PluginLoader(
'CacheModule', 'CacheModule',
'ansible.plugins.cache', 'ansible.plugins.cache',
C.DEFAULT_CACHE_PLUGIN_PATH, C.DEFAULT_CACHE_PLUGIN_PATH,
'cache_plugins' 'cache_plugins',
) )
callback_loader = PluginLoader( callback_loader = PluginLoader(
'CallbackModule', 'CallbackModule',
'ansible.plugins.callback', 'ansible.plugins.callback',
C.DEFAULT_CALLBACK_PLUGIN_PATH, C.DEFAULT_CALLBACK_PLUGIN_PATH,
'callback_plugins' 'callback_plugins',
) )
connection_loader = PluginLoader( connection_loader = PluginLoader(
@ -321,7 +332,8 @@ connection_loader = PluginLoader(
'ansible.plugins.connections', 'ansible.plugins.connections',
C.DEFAULT_CONNECTION_PLUGIN_PATH, C.DEFAULT_CONNECTION_PLUGIN_PATH,
'connection_plugins', 'connection_plugins',
aliases={'paramiko': 'paramiko_ssh'} aliases={'paramiko': 'paramiko_ssh'},
required_base_class='ConnectionBase',
) )
shell_loader = PluginLoader( shell_loader = PluginLoader(
@ -335,28 +347,29 @@ module_loader = PluginLoader(
'', '',
'ansible.modules', 'ansible.modules',
C.DEFAULT_MODULE_PATH, C.DEFAULT_MODULE_PATH,
'library' 'library',
) )
lookup_loader = PluginLoader( lookup_loader = PluginLoader(
'LookupModule', 'LookupModule',
'ansible.plugins.lookup', 'ansible.plugins.lookup',
C.DEFAULT_LOOKUP_PLUGIN_PATH, C.DEFAULT_LOOKUP_PLUGIN_PATH,
'lookup_plugins' 'lookup_plugins',
required_base_class='LookupBase',
) )
vars_loader = PluginLoader( vars_loader = PluginLoader(
'VarsModule', 'VarsModule',
'ansible.plugins.vars', 'ansible.plugins.vars',
C.DEFAULT_VARS_PLUGIN_PATH, C.DEFAULT_VARS_PLUGIN_PATH,
'vars_plugins' 'vars_plugins',
) )
filter_loader = PluginLoader( filter_loader = PluginLoader(
'FilterModule', 'FilterModule',
'ansible.plugins.filter', 'ansible.plugins.filter',
C.DEFAULT_FILTER_PLUGIN_PATH, C.DEFAULT_FILTER_PLUGIN_PATH,
'filter_plugins' 'filter_plugins',
) )
fragment_loader = PluginLoader( fragment_loader = PluginLoader(
@ -371,4 +384,5 @@ strategy_loader = PluginLoader(
'ansible.plugins.strategies', 'ansible.plugins.strategies',
None, None,
'strategy_plugins', 'strategy_plugins',
required_base_class='StrategyBase',
) )

View file

@ -34,6 +34,7 @@ from ansible.parsing.utils.jsonify import jsonify
from ansible.plugins import shell_loader from ansible.plugins import shell_loader
from ansible.utils.debug import debug from ansible.utils.debug import debug
from ansible.utils.unicode import to_bytes
class ActionBase: class ActionBase:
@ -51,21 +52,21 @@ class ActionBase:
self._loader = loader self._loader = loader
self._templar = templar self._templar = templar
self._shared_loader_obj = shared_loader_obj self._shared_loader_obj = shared_loader_obj
self._shell = self.get_shell()
# load the shell plugin for this action/connection
if self._connection_info.shell:
shell_type = self._connection_info.shell
elif hasattr(connection, '_shell'):
shell_type = getattr(connection, '_shell')
else:
shell_type = os.path.basename(C.DEFAULT_EXECUTABLE)
self._shell = shell_loader.get(shell_type)
if not self._shell:
raise AnsibleError("Invalid shell type specified (%s), or the plugin for that shell type is missing." % shell_type)
self._supports_check_mode = True self._supports_check_mode = True
def get_shell(self):
if hasattr(self._connection, '_shell'):
shell_plugin = getattr(self._connection, '_shell', '')
else:
shell_plugin = shell_loader.get(os.path.basename(C.DEFAULT_EXECUTABLE))
if shell_plugin is None:
shell_plugin = shell_loader.get('sh')
return shell_plugin
def _configure_module(self, module_name, module_args): def _configure_module(self, module_name, module_args):
''' '''
Handles the loading and templating of the module code through the Handles the loading and templating of the module code through the
@ -201,18 +202,13 @@ class ActionBase:
Copies the module data out to the temporary module path. Copies the module data out to the temporary module path.
''' '''
if type(data) == dict: if isinstance(data, dict):
data = jsonify(data) data = jsonify(data)
afd, afile = tempfile.mkstemp() afd, afile = tempfile.mkstemp()
afo = os.fdopen(afd, 'w') afo = os.fdopen(afd, 'w')
try: try:
# FIXME: is this still necessary? data = to_bytes(data, errors='strict')
#if not isinstance(data, unicode):
# #ensure the data is valid UTF-8
# data = data.decode('utf-8')
#else:
# data = data.encode('utf-8')
afo.write(data) afo.write(data)
except Exception as e: except Exception as e:
#raise AnsibleError("failure encoding into utf-8: %s" % str(e)) #raise AnsibleError("failure encoding into utf-8: %s" % str(e))

View file

@ -212,7 +212,11 @@ class VariableManager:
# FIXME: make sure all special vars are here # FIXME: make sure all special vars are here
# Finally, we create special vars # Finally, we create special vars
if host and self._inventory is not None:
if host:
all_vars['groups'] = [group.name for group in host.get_groups()]
if self._inventory is not None:
hostvars = HostVars(vars_manager=self, inventory=self._inventory, loader=loader) hostvars = HostVars(vars_manager=self, inventory=self._inventory, loader=loader)
all_vars['hostvars'] = hostvars all_vars['hostvars'] = hostvars