Move persistent connections to only use registered variables (#45616)
* Try to intuit proper plugins to send to ansible-connection * Move sub-plugins to init so that vars will be populated in executor * Fix connection unit tests
This commit is contained in:
parent
86c48205c4
commit
406b59aeba
10 changed files with 98 additions and 109 deletions
|
@ -842,17 +842,29 @@ class TaskExecutor:
|
|||
self._play_context.timeout = connection.get_option('persistent_command_timeout')
|
||||
display.vvvv('attempting to start connection', host=self._play_context.remote_addr)
|
||||
display.vvvv('using connection plugin %s' % connection.transport, host=self._play_context.remote_addr)
|
||||
# We don't need to send the entire contents of variables to ansible-connection
|
||||
filtered_vars = dict(
|
||||
(key, value) for key, value in variables.items()
|
||||
if key.startswith('ansible') and key != 'ansible_failed_task'
|
||||
)
|
||||
socket_path = self._start_connection(filtered_vars)
|
||||
|
||||
options = self._get_persistent_connection_options(connection, variables, templar)
|
||||
socket_path = self._start_connection(options)
|
||||
display.vvvv('local domain socket path is %s' % socket_path, host=self._play_context.remote_addr)
|
||||
setattr(connection, '_socket_path', socket_path)
|
||||
|
||||
return connection
|
||||
|
||||
def _get_persistent_connection_options(self, connection, variables, templar):
|
||||
final_vars = combine_vars(variables, variables.get('ansible_delegated_vars', dict()).get(self._task.delegate_to, dict()))
|
||||
|
||||
option_vars = C.config.get_plugin_vars('connection', connection._load_name)
|
||||
for plugin in connection._sub_plugins:
|
||||
if plugin['type'] != 'external':
|
||||
option_vars.extend(C.config.get_plugin_vars(plugin['type'], plugin['name']))
|
||||
|
||||
options = {}
|
||||
for k in option_vars:
|
||||
if k in final_vars:
|
||||
options[k] = templar.template(final_vars[k])
|
||||
|
||||
return options
|
||||
|
||||
def _set_connection_options(self, variables, templar):
|
||||
|
||||
# Keep the pre-delegate values for these keys
|
||||
|
|
|
@ -300,7 +300,7 @@ class NetworkConnectionBase(ConnectionBase):
|
|||
self._local = connection_loader.get('local', play_context, '/dev/null')
|
||||
self._local.set_options()
|
||||
|
||||
self._implementation_plugins = []
|
||||
self._sub_plugins = []
|
||||
self._cached_variables = (None, None, None)
|
||||
|
||||
# reconstruct the socket_path and set instance values accordingly
|
||||
|
@ -312,16 +312,12 @@ class NetworkConnectionBase(ConnectionBase):
|
|||
return self.__dict__[name]
|
||||
except KeyError:
|
||||
if not name.startswith('_'):
|
||||
for plugin in self._implementation_plugins:
|
||||
method = getattr(plugin, name, None)
|
||||
for plugin in self._sub_plugins:
|
||||
method = getattr(plugin['obj'], name, None)
|
||||
if method is not None:
|
||||
return method
|
||||
raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name))
|
||||
|
||||
def _connect(self):
|
||||
self.set_implementation_plugin_options(*self._cached_variables)
|
||||
self._cached_variables = (None, None, None)
|
||||
|
||||
def exec_command(self, cmd, in_data=None, sudoable=True):
|
||||
return self._local.exec_command(cmd, in_data, sudoable)
|
||||
|
||||
|
@ -345,25 +341,16 @@ class NetworkConnectionBase(ConnectionBase):
|
|||
def close(self):
|
||||
if self._connected:
|
||||
self._connected = False
|
||||
self._implementation_plugins = []
|
||||
|
||||
def set_options(self, task_keys=None, var_options=None, direct=None):
|
||||
super(NetworkConnectionBase, self).set_options(task_keys=task_keys, var_options=var_options, direct=direct)
|
||||
|
||||
if self._implementation_plugins:
|
||||
self.set_implementation_plugin_options(task_keys, var_options, direct)
|
||||
else:
|
||||
self._cached_variables = (task_keys, var_options, direct)
|
||||
|
||||
def set_implementation_plugin_options(self, task_keys=None, var_options=None, direct=None):
|
||||
'''
|
||||
initialize implementation plugin options
|
||||
'''
|
||||
for plugin in self._implementation_plugins:
|
||||
try:
|
||||
plugin.set_options(task_keys=task_keys, var_options=var_options, direct=direct)
|
||||
except AttributeError:
|
||||
pass
|
||||
for plugin in self._sub_plugins:
|
||||
if plugin['type'] != 'external':
|
||||
try:
|
||||
plugin['obj'].set_options(task_keys=task_keys, var_options=var_options, direct=direct)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def _update_connection_state(self):
|
||||
'''
|
||||
|
|
|
@ -176,7 +176,22 @@ class Connection(NetworkConnectionBase):
|
|||
self._url = None
|
||||
self._auth = None
|
||||
|
||||
if not self._network_os:
|
||||
if self._network_os:
|
||||
|
||||
self.httpapi = httpapi_loader.get(self._network_os, self)
|
||||
if self.httpapi:
|
||||
self._sub_plugins.append({'type': 'httpapi', 'name': self._network_os, 'obj': self.httpapi})
|
||||
display.vvvv('loaded API plugin for network_os %s' % self._network_os)
|
||||
else:
|
||||
raise AnsibleConnectionFailure('unable to load API plugin for network_os %s' % self._network_os)
|
||||
|
||||
self.cliconf = cliconf_loader.get(self._network_os, self)
|
||||
if self.cliconf:
|
||||
self._sub_plugins.append({'type': 'cliconf', 'name': self._network_os, 'obj': self.cliconf})
|
||||
display.vvvv('loaded cliconf plugin for network_os %s' % self._network_os)
|
||||
else:
|
||||
display.vvvv('unable to load cliconf for network_os %s' % self._network_os)
|
||||
else:
|
||||
raise AnsibleConnectionFailure(
|
||||
'Unable to automatically determine host network os. Please '
|
||||
'manually configure ansible_network_os value for this host'
|
||||
|
@ -211,24 +226,8 @@ class Connection(NetworkConnectionBase):
|
|||
port = self.get_option('port') or (443 if protocol == 'https' else 80)
|
||||
self._url = '%s://%s:%s' % (protocol, host, port)
|
||||
|
||||
httpapi = httpapi_loader.get(self._network_os, self)
|
||||
if httpapi:
|
||||
display.vvvv('loaded API plugin for network_os %s' % self._network_os, host=host)
|
||||
self._implementation_plugins.append(httpapi)
|
||||
else:
|
||||
raise AnsibleConnectionFailure('unable to load API plugin for network_os %s' % self._network_os)
|
||||
|
||||
cliconf = cliconf_loader.get(self._network_os, self)
|
||||
if cliconf:
|
||||
display.vvvv('loaded cliconf plugin for network_os %s' % self._network_os, host=host)
|
||||
self._implementation_plugins.append(cliconf)
|
||||
else:
|
||||
display.vvvv('unable to load cliconf for network_os %s' % self._network_os)
|
||||
|
||||
super(Connection, self)._connect()
|
||||
|
||||
httpapi.set_become(self._play_context)
|
||||
httpapi.login(self.get_option('remote_user'), self.get_option('password'))
|
||||
self.httpapi.set_become(self._play_context)
|
||||
self.httpapi.login(self.get_option('remote_user'), self.get_option('password'))
|
||||
|
||||
self._connected = True
|
||||
|
||||
|
|
|
@ -186,7 +186,7 @@ class Connection(NetworkConnectionBase):
|
|||
|
||||
self.napalm.open()
|
||||
|
||||
self._implementation_plugins.append(self.napalm)
|
||||
self._sub_plugins.append({'type': 'external', 'name': 'napalm', 'obj': self.napalm})
|
||||
display.vvvv('created napalm device for network_os %s' % self._network_os, host=host)
|
||||
self._connected = True
|
||||
|
||||
|
|
|
@ -217,6 +217,15 @@ class Connection(NetworkConnectionBase):
|
|||
super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs)
|
||||
|
||||
self._network_os = self._network_os or 'default'
|
||||
|
||||
netconf = netconf_loader.get(self._network_os, self)
|
||||
if netconf:
|
||||
self._sub_plugins.append({'type': 'netconf', 'name': self._network_os, 'obj': netconf})
|
||||
display.display('loaded netconf plugin for network_os %s' % self._network_os, log_only=True)
|
||||
else:
|
||||
netconf = netconf_loader.get("default", self)
|
||||
self._sub_plugins.append({'type': 'netconf', 'name': 'default', 'obj': netconf})
|
||||
display.display('unable to load netconf plugin for network_os %s, falling back to default plugin' % self._network_os)
|
||||
display.display('network_os is set to %s' % self._network_os, log_only=True)
|
||||
|
||||
self._manager = None
|
||||
|
@ -246,8 +255,6 @@ class Connection(NetworkConnectionBase):
|
|||
return super(Connection, self).exec_command(cmd, in_data, sudoable)
|
||||
|
||||
def _connect(self):
|
||||
super(Connection, self)._connect()
|
||||
|
||||
display.display('ssh connection done, starting ncclient', log_only=True)
|
||||
|
||||
allow_agent = True
|
||||
|
@ -300,14 +307,6 @@ class Connection(NetworkConnectionBase):
|
|||
|
||||
self._connected = True
|
||||
|
||||
netconf = netconf_loader.get(self._network_os, self)
|
||||
if netconf:
|
||||
display.display('loaded netconf plugin for network_os %s' % self._network_os, log_only=True)
|
||||
else:
|
||||
netconf = netconf_loader.get("default", self)
|
||||
display.display('unable to load netconf plugin for network_os %s, falling back to default plugin' % self._network_os)
|
||||
self._implementation_plugins.append(netconf)
|
||||
|
||||
super(Connection, self)._connect()
|
||||
|
||||
return 0, to_bytes(self._manager.session_id, errors='surrogate_or_strict'), b''
|
||||
|
|
|
@ -208,6 +208,21 @@ class Connection(NetworkConnectionBase):
|
|||
if self._play_context.verbosity > 3:
|
||||
logging.getLogger('paramiko').setLevel(logging.DEBUG)
|
||||
|
||||
if self._network_os:
|
||||
|
||||
self.cliconf = cliconf_loader.get(self._network_os, self)
|
||||
if self.cliconf:
|
||||
display.vvvv('loaded cliconf plugin for network_os %s' % self._network_os)
|
||||
self._sub_plugins.append({'type': 'cliconf', 'name': self._network_os, 'obj': self.cliconf})
|
||||
else:
|
||||
display.vvvv('unable to load cliconf for network_os %s' % self._network_os)
|
||||
else:
|
||||
raise AnsibleConnectionFailure(
|
||||
'Unable to automatically determine host network os. Please '
|
||||
'manually configure ansible_network_os value for this host'
|
||||
)
|
||||
display.display('network_os is set to %s' % self._network_os, log_only=True)
|
||||
|
||||
def _get_log_channel(self):
|
||||
name = "p=%s u=%s | " % (os.getpid(), getpass.getuser())
|
||||
name += "paramiko [%s]" % self._play_context.remote_addr
|
||||
|
@ -270,13 +285,6 @@ class Connection(NetworkConnectionBase):
|
|||
Connects to the remote device and starts the terminal
|
||||
'''
|
||||
if not self.connected:
|
||||
if not self._network_os:
|
||||
raise AnsibleConnectionFailure(
|
||||
'Unable to automatically determine host network os. Please '
|
||||
'manually configure ansible_network_os value for this host'
|
||||
)
|
||||
display.display('network_os is set to %s' % self._network_os, log_only=True)
|
||||
|
||||
self.paramiko_conn = connection_loader.get('paramiko', self._play_context, '/dev/null')
|
||||
self.paramiko_conn._set_log_channel(self._get_log_channel())
|
||||
self.paramiko_conn.set_options(direct={'look_for_keys': not bool(self._play_context.password and not self._play_context.private_key_file)})
|
||||
|
@ -295,15 +303,6 @@ class Connection(NetworkConnectionBase):
|
|||
|
||||
display.vvvv('loaded terminal plugin for network_os %s' % self._network_os, host=host)
|
||||
|
||||
self.cliconf = cliconf_loader.get(self._network_os, self)
|
||||
if self.cliconf:
|
||||
display.vvvv('loaded cliconf plugin for network_os %s' % self._network_os, host=host)
|
||||
self._implementation_plugins.append(self.cliconf)
|
||||
else:
|
||||
display.vvvv('unable to load cliconf for network_os %s' % self._network_os)
|
||||
|
||||
super(Connection, self)._connect()
|
||||
|
||||
self.receive(prompts=self._terminal.terminal_initial_prompt, answer=self._terminal.terminal_initial_answer,
|
||||
newline=self._terminal.terminal_inital_prompt_newline)
|
||||
|
||||
|
|
|
@ -102,7 +102,10 @@ class NetconfBase(AnsiblePlugin):
|
|||
|
||||
def __init__(self, connection):
|
||||
self._connection = connection
|
||||
self.m = self._connection._manager
|
||||
|
||||
@property
|
||||
def m(self):
|
||||
return self._connection._manager
|
||||
|
||||
@ensure_connected
|
||||
def rpc(self, name):
|
||||
|
|
|
@ -41,6 +41,7 @@ from ansible.plugins.connection.ssh import Connection as SSHConnection
|
|||
from ansible.plugins.connection.docker import Connection as DockerConnection
|
||||
# from ansible.plugins.connection.winrm import Connection as WinRmConnection
|
||||
from ansible.plugins.connection.network_cli import Connection as NetworkCliConnection
|
||||
from ansible.plugins.connection.httpapi import Connection as HttpapiConnection
|
||||
|
||||
PY3 = sys.version_info[0] == 3
|
||||
|
||||
|
@ -162,11 +163,16 @@ class TestConnectionBaseClass(unittest.TestCase):
|
|||
# self.assertIsInstance(WinRmConnection(), WinRmConnection)
|
||||
|
||||
def test_network_cli_connection_module(self):
|
||||
self.play_context.network_os = 'eos'
|
||||
self.assertIsInstance(NetworkCliConnection(self.play_context, self.in_stream), NetworkCliConnection)
|
||||
|
||||
def test_netconf_connection_module(self):
|
||||
self.assertIsInstance(NetconfConnection(self.play_context, self.in_stream), NetconfConnection)
|
||||
|
||||
def test_httpapi_connection_module(self):
|
||||
self.play_context.network_os = 'eos'
|
||||
self.assertIsInstance(HttpapiConnection(self.play_context, self.in_stream), HttpapiConnection)
|
||||
|
||||
def test_check_password_prompt(self):
|
||||
local = (
|
||||
b'[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: \n'
|
||||
|
|
|
@ -58,9 +58,7 @@ class TestNetconfConnectionClass(unittest.TestCase):
|
|||
|
||||
def test_netconf_init(self):
|
||||
pc = PlayContext()
|
||||
new_stdin = StringIO()
|
||||
|
||||
conn = netconf.Connection(pc, new_stdin)
|
||||
conn = connection_loader.get('netconf', pc, '/dev/null')
|
||||
|
||||
self.assertEqual('default', conn._network_os)
|
||||
self.assertIsNone(conn._manager)
|
||||
|
@ -69,14 +67,11 @@ class TestNetconfConnectionClass(unittest.TestCase):
|
|||
@patch("ansible.plugins.connection.netconf.netconf_loader")
|
||||
def test_netconf__connect(self, mock_netconf_loader):
|
||||
pc = PlayContext()
|
||||
new_stdin = StringIO()
|
||||
|
||||
conn = connection_loader.get('netconf', pc, new_stdin)
|
||||
conn = connection_loader.get('netconf', pc, '/dev/null')
|
||||
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.session_id = '123456789'
|
||||
netconf.manager.connect = MagicMock(return_value=mock_manager)
|
||||
conn._play_context.network_os = 'default'
|
||||
|
||||
rc, out, err = conn._connect()
|
||||
|
||||
|
@ -87,9 +82,8 @@ class TestNetconfConnectionClass(unittest.TestCase):
|
|||
|
||||
def test_netconf_exec_command(self):
|
||||
pc = PlayContext()
|
||||
new_stdin = StringIO()
|
||||
conn = connection_loader.get('netconf', pc, '/dev/null')
|
||||
|
||||
conn = netconf.Connection(pc, new_stdin)
|
||||
conn._connected = True
|
||||
|
||||
mock_reply = MagicMock(name='reply')
|
||||
|
@ -105,9 +99,8 @@ class TestNetconfConnectionClass(unittest.TestCase):
|
|||
|
||||
def test_netconf_exec_command_invalid_request(self):
|
||||
pc = PlayContext()
|
||||
new_stdin = StringIO()
|
||||
conn = connection_loader.get('netconf', pc, '/dev/null')
|
||||
|
||||
conn = netconf.Connection(pc, new_stdin)
|
||||
conn._connected = True
|
||||
|
||||
mock_manager = MagicMock(name='self._manager')
|
||||
|
|
|
@ -30,7 +30,6 @@ from ansible.compat.tests.mock import patch, MagicMock
|
|||
|
||||
from ansible.errors import AnsibleConnectionFailure
|
||||
from ansible.playbook.play_context import PlayContext
|
||||
from ansible.plugins.connection import network_cli
|
||||
from ansible.plugins.loader import connection_loader
|
||||
|
||||
|
||||
|
@ -39,39 +38,30 @@ class TestConnectionClass(unittest.TestCase):
|
|||
@patch("ansible.plugins.connection.paramiko_ssh.Connection._connect")
|
||||
def test_network_cli__connect_error(self, mocked_super):
|
||||
pc = PlayContext()
|
||||
new_stdin = StringIO()
|
||||
|
||||
pc.network_os = 'ios'
|
||||
conn = connection_loader.get('network_cli', pc, '/dev/null')
|
||||
|
||||
conn.ssh = MagicMock()
|
||||
conn.receive = MagicMock()
|
||||
conn._terminal = MagicMock()
|
||||
pc.network_os = None
|
||||
conn._network_os = 'does not exist'
|
||||
|
||||
self.assertRaises(AnsibleConnectionFailure, conn._connect)
|
||||
|
||||
@patch("ansible.plugins.connection.paramiko_ssh.Connection._connect")
|
||||
def test_network_cli__invalid_os(self, mocked_super):
|
||||
def test_network_cli__invalid_os(self):
|
||||
pc = PlayContext()
|
||||
new_stdin = StringIO()
|
||||
|
||||
conn = connection_loader.get('network_cli', pc, '/dev/null')
|
||||
conn.ssh = MagicMock()
|
||||
conn.receive = MagicMock()
|
||||
conn._terminal = MagicMock()
|
||||
pc.network_os = None
|
||||
self.assertRaises(AnsibleConnectionFailure, conn._connect)
|
||||
|
||||
self.assertRaises(AnsibleConnectionFailure, connection_loader.get, 'network_cli', pc, '/dev/null')
|
||||
|
||||
@patch("ansible.plugins.connection.network_cli.terminal_loader")
|
||||
@patch("ansible.plugins.connection.paramiko_ssh.Connection._connect")
|
||||
def test_network_cli__connect(self, mocked_super, mocked_terminal_loader):
|
||||
pc = PlayContext()
|
||||
pc.network_os = 'ios'
|
||||
new_stdin = StringIO()
|
||||
|
||||
conn = connection_loader.get('network_cli', pc, '/dev/null')
|
||||
|
||||
conn.ssh = MagicMock()
|
||||
conn.receive = MagicMock()
|
||||
conn._terminal = MagicMock()
|
||||
|
||||
conn._connect()
|
||||
self.assertTrue(conn._terminal.on_open_shell.called)
|
||||
|
@ -88,8 +78,8 @@ class TestConnectionClass(unittest.TestCase):
|
|||
@patch("ansible.plugins.connection.paramiko_ssh.Connection.close")
|
||||
def test_network_cli_close(self, mocked_super):
|
||||
pc = PlayContext()
|
||||
new_stdin = StringIO()
|
||||
conn = network_cli.Connection(pc, new_stdin)
|
||||
pc.network_os = 'ios'
|
||||
conn = connection_loader.get('network_cli', pc, '/dev/null')
|
||||
|
||||
terminal = MagicMock(supports_multiplexing=False)
|
||||
conn._terminal = terminal
|
||||
|
@ -105,8 +95,8 @@ class TestConnectionClass(unittest.TestCase):
|
|||
@patch("ansible.plugins.connection.paramiko_ssh.Connection._connect")
|
||||
def test_network_cli_exec_command(self, mocked_super):
|
||||
pc = PlayContext()
|
||||
new_stdin = StringIO()
|
||||
conn = network_cli.Connection(pc, new_stdin)
|
||||
pc.network_os = 'ios'
|
||||
conn = connection_loader.get('network_cli', pc, '/dev/null')
|
||||
|
||||
mock_send = MagicMock(return_value=b'command response')
|
||||
conn.send = mock_send
|
||||
|
@ -124,8 +114,9 @@ class TestConnectionClass(unittest.TestCase):
|
|||
|
||||
def test_network_cli_send(self):
|
||||
pc = PlayContext()
|
||||
new_stdin = StringIO()
|
||||
conn = network_cli.Connection(pc, new_stdin)
|
||||
pc.network_os = 'ios'
|
||||
conn = connection_loader.get('network_cli', pc, '/dev/null')
|
||||
|
||||
mock__terminal = MagicMock()
|
||||
mock__terminal.terminal_stdout_re = [re.compile(b'device#')]
|
||||
mock__terminal.terminal_stderr_re = [re.compile(b'^ERROR')]
|
||||
|
|
Loading…
Reference in a new issue