diff --git a/lib/ansible/parsing/mod_args.py b/lib/ansible/parsing/mod_args.py index 5ba0faac70f..a6ccaa0a565 100644 --- a/lib/ansible/parsing/mod_args.py +++ b/lib/ansible/parsing/mod_args.py @@ -253,16 +253,16 @@ class ModuleArgsParser: action, args = self._normalize_parameters(thing, additional_args=additional_args) # local_action + local_action = False if 'local_action' in self._task_ds: # local_action is similar but also implies a connection='local' if action is not None: raise AnsibleParserError("action and local_action are mutually exclusive", obj=self._task_ds) thing = self._task_ds.get('local_action', '') connection = 'local' + local_action = True action, args = self._normalize_parameters(thing, additional_args=additional_args) - # module: is the more new-style invocation - # walk the input dictionary to see we recognize a module name for (item, value) in iteritems(self._task_ds): if item in module_loader or item == 'meta' or item == 'include': @@ -287,4 +287,8 @@ class ModuleArgsParser: # shell modules require special handling (action, args) = self._handle_shell_weirdness(action, args) + # now add the local action flag to the args, if it was set + if local_action: + args['_local_action'] = local_action + return (action, args, connection) diff --git a/lib/ansible/playbook/play_context.py b/lib/ansible/playbook/play_context.py index 55d7d99b5a4..1706ef413c1 100644 --- a/lib/ansible/playbook/play_context.py +++ b/lib/ansible/playbook/play_context.py @@ -310,7 +310,7 @@ class PlayContext(Base): if attr_val is not None: setattr(new_info, attr, attr_val) - # finally, use the MAGIC_VARIABLE_MAPPING dictionary to update this + # next, use the MAGIC_VARIABLE_MAPPING dictionary to update this # connection info object with 'magic' variables from the variable list for (attr, variable_names) in iteritems(MAGIC_VARIABLE_MAPPING): for variable_name in variable_names: @@ -328,6 +328,12 @@ class PlayContext(Base): elif new_info.become_method == 'su' and new_info.su_pass: setattr(new_info, 'become_pass', new_info.su_pass) + # finally, in the special instance that the task was specified + # as a local action, override the connection in case it was changed + # during some other step in the process + if task._local_action: + setattr(new_info, 'connection', 'local') + return new_info def make_become_cmd(self, cmd, executable=None): diff --git a/lib/ansible/playbook/task.py b/lib/ansible/playbook/task.py index fef629c60f1..8f7e1b77156 100644 --- a/lib/ansible/playbook/task.py +++ b/lib/ansible/playbook/task.py @@ -95,6 +95,10 @@ class Task(Base, Conditional, Taggable, Become): self._role = role self._task_include = task_include + # special flag for local_action: tasks, to make sure their + # connection type of local isn't overridden incorrectly + self._local_action = False + super(Task, self).__init__() def get_name(self): @@ -130,6 +134,16 @@ class Task(Base, Conditional, Taggable, Become): t = Task(block=block, role=role, task_include=task_include) return t.load_data(data, variable_manager=variable_manager, loader=loader) + def load_data(self, ds, variable_manager=None, loader=None): + ''' + We override load_data for tasks so that we can pull special flags + out of the task args and set them internaly only so the user never + sees them. + ''' + t = super(Task, self).load_data(ds=ds, variable_manager=variable_manager, loader=loader) + t._local_action = t.args.pop('_local_action', False) + return t + def __repr__(self): ''' returns a human readable representation of the task ''' return "TASK: %s" % self.get_name() @@ -260,6 +274,7 @@ class Task(Base, Conditional, Taggable, Become): def copy(self, exclude_block=False): new_me = super(Task, self).copy() + new_me._local_action = self._local_action new_me._block = None if self._block and not exclude_block: @@ -277,6 +292,7 @@ class Task(Base, Conditional, Taggable, Become): def serialize(self): data = super(Task, self).serialize() + data['_local_action'] = self._local_action if self._block: data['block'] = self._block.serialize() @@ -295,6 +311,7 @@ class Task(Base, Conditional, Taggable, Become): #from ansible.playbook.task_include import TaskInclude block_data = data.get('block') + self._local_action = data.get('_local_action', False) if block_data: b = Block() diff --git a/test/units/parsing/test_mod_args.py b/test/units/parsing/test_mod_args.py index bce31d6f1f8..1d5f817cb07 100644 --- a/test/units/parsing/test_mod_args.py +++ b/test/units/parsing/test_mod_args.py @@ -112,7 +112,7 @@ class TestModArgsDwim(unittest.TestCase): mod, args, connection = m.parse() self._debug(mod, args, connection) self.assertEqual(mod, 'copy') - self.assertEqual(args, dict(src='a', dest='b')) + self.assertEqual(args, dict(src='a', dest='b', _local_action=True)) self.assertIs(connection, 'local') def test_multiple_actions(self):