diff --git a/lib/ansible/plugins/connection/network_cli.py b/lib/ansible/plugins/connection/network_cli.py index aa97e5d204e..6959ec31e60 100644 --- a/lib/ansible/plugins/connection/network_cli.py +++ b/lib/ansible/plugins/connection/network_cli.py @@ -302,6 +302,9 @@ class Connection(NetworkConnectionBase): logging.getLogger('paramiko').setLevel(logging.DEBUG) if self._network_os: + self._terminal = terminal_loader.get(self._network_os, self) + if not self._terminal: + raise AnsibleConnectionFailure('network os %s is not supported' % self._network_os) self.cliconf = cliconf_loader.get(self._network_os, self) if self.cliconf: @@ -391,10 +394,6 @@ class Connection(NetworkConnectionBase): self._ssh_shell = ssh.ssh.invoke_shell() self._ssh_shell.settimeout(self.get_option('persistent_command_timeout')) - self._terminal = terminal_loader.get(self._network_os, self) - if not self._terminal: - raise AnsibleConnectionFailure('network os %s is not supported' % self._network_os) - self.queue_message('vvvv', 'loaded terminal plugin for network_os %s' % self._network_os) terminal_initial_prompt = self.get_option('terminal_initial_prompt') or self._terminal.terminal_initial_prompt diff --git a/test/units/plugins/connection/test_network_cli.py b/test/units/plugins/connection/test_network_cli.py index bbb40676f8e..a625449a72e 100644 --- a/test/units/plugins/connection/test_network_cli.py +++ b/test/units/plugins/connection/test_network_cli.py @@ -34,19 +34,13 @@ from ansible.plugins.loader import connection_loader class TestConnectionClass(unittest.TestCase): - @patch("ansible.plugins.connection.paramiko_ssh.Connection._connect") - def test_network_cli__connect_error(self, mocked_super): - pc = PlayContext() - pc.network_os = 'ios' - conn = connection_loader.get('network_cli', pc, '/dev/null') - - conn.ssh = MagicMock() - conn.receive = MagicMock() - conn._network_os = 'does not exist' - - self.assertRaises(AnsibleConnectionFailure, conn._connect) - def test_network_cli__invalid_os(self): + pc = PlayContext() + pc.network_os = 'does not exist' + + self.assertRaises(AnsibleConnectionFailure, connection_loader.get, 'network_cli', pc, '/dev/null') + + def test_network_cli__no_os(self): pc = PlayContext() pc.network_os = None