Connection plugins network_cli and netconf (#32521)

* implements jsonrpc message passing for ansible-connection

* implements more generic mechanism for persistent connections
* starts persistent connection in task_executor if enabled and supported
* supports using network_cli as top level connection plugin
* enhances logging for persistent connection to stdout

* Update action plugins

* Fix Python3 RPC

* Fix Junos bytes<-->str issues

* supports using netconf as top level connection plugin

* Error message when running netconf on an unsupported platform
* Update tests

* Fix `authorize: yes` for `connection: local`

* Handle potentially JSON data in terminal

* Add clarifying detail if possible on ConnectionError
This commit is contained in:
Nathaniel Case 2017-11-09 15:04:40 -05:00 committed by GitHub
parent 897b31f249
commit 9c0275a879
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 722 additions and 798 deletions

View file

@ -1,23 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# Copyright: (c) 2017, Ansible Project
# (c) 2017, Ansible, Inc. <support@ansible.com> # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
########################################################
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
@ -36,91 +19,68 @@ import socket
import sys import sys
import time import time
import traceback import traceback
import datetime
import errno import errno
import json
from ansible import constants as C from ansible import constants as C
from ansible.module_utils._text import to_bytes, to_native, to_text from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.module_utils.six import PY3 from ansible.module_utils.six import PY3
from ansible.module_utils.six.moves import cPickle from ansible.module_utils.six.moves import cPickle
from ansible.module_utils.connection import send_data, recv_data from ansible.module_utils.connection import send_data, recv_data
from ansible.module_utils.service import fork_process
from ansible.playbook.play_context import PlayContext from ansible.playbook.play_context import PlayContext
from ansible.plugins.loader import connection_loader from ansible.plugins.loader import connection_loader
from ansible.utils.path import unfrackpath, makedirs_safe from ansible.utils.path import unfrackpath, makedirs_safe
from ansible.errors import AnsibleConnectionFailure from ansible.errors import AnsibleError
from ansible.utils.display import Display from ansible.utils.display import Display
from ansible.utils.jsonrpc import JsonRpcServer
def do_fork(): class ConnectionProcess(object):
''' '''
Does the required double fork for a daemon process. Based on The connection process wraps around a Connection object that manages
http://code.activestate.com/recipes/66012-fork-a-daemon-process-on-unix/ the connection to a remote device that persists over the playbook
''' '''
try: def __init__(self, fd, play_context, socket_path, original_path):
pid = os.fork()
if pid > 0:
return pid
# This is done as a 'good practice' for daemons, but we need to keep the cwd
# leaving it here as a note that we KNOW its good practice but are not doing it on purpose.
# os.chdir("/")
os.setsid()
os.umask(0)
try:
pid = os.fork()
if pid > 0:
sys.exit(0)
if C.DEFAULT_LOG_PATH != '':
out_file = open(C.DEFAULT_LOG_PATH, 'ab+')
err_file = open(C.DEFAULT_LOG_PATH, 'ab+', 0)
else:
out_file = open('/dev/null', 'ab+')
err_file = open('/dev/null', 'ab+', 0)
os.dup2(out_file.fileno(), sys.stdout.fileno())
os.dup2(err_file.fileno(), sys.stderr.fileno())
os.close(sys.stdin.fileno())
return pid
except OSError as e:
sys.exit(1)
except OSError as e:
sys.exit(1)
class Server():
def __init__(self, socket_path, play_context):
self.socket_path = socket_path
self.play_context = play_context self.play_context = play_context
self.socket_path = socket_path
self.original_path = original_path
display.display( self.fd = fd
'creating new control socket for host %s:%s as user %s' % self.exception = None
(play_context.remote_addr, play_context.port, play_context.remote_user),
log_only=True
)
display.display('control socket path is %s' % socket_path, log_only=True) self.srv = JsonRpcServer()
display.display('current working directory is %s' % os.getcwd(), log_only=True) self.sock = None
self._start_time = datetime.datetime.now() def start(self):
try:
messages = list()
result = {}
display.display("using connection plugin %s" % self.play_context.connection, log_only=True) messages.append('control socket path is %s' % self.socket_path)
self.connection = connection_loader.get(play_context.connection, play_context, sys.stdin) # If this is a relative path (~ gets expanded later) then plug the
self.connection._connect() # key's path on to the directory we originally came from, so we can
# find it now that our cwd is /
if self.play_context.private_key_file and self.play_context.private_key_file[0] not in '~/':
self.play_context.private_key_file = os.path.join(self.original_path, self.play_context.private_key_file)
if not self.connection.connected: self.connection = connection_loader.get(self.play_context.connection, self.play_context, '/dev/null')
raise AnsibleConnectionFailure('unable to connect to remote host %s' % self._play_context.remote_addr) self.connection._connect()
self.srv.register(self.connection)
messages.append('connection to remote device started successfully')
connection_time = datetime.datetime.now() - self._start_time self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
display.display('connection established to %s in %s' % (play_context.remote_addr, connection_time), log_only=True) self.sock.bind(self.socket_path)
self.sock.listen(1)
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) messages.append('local domain socket listeners started successfully')
self.socket.bind(self.socket_path) except Exception as exc:
self.socket.listen(1) result['error'] = to_text(exc)
display.display('local socket is set to listening', log_only=True) result['exception'] = traceback.format_exc()
finally:
result['messages'] = messages
self.fd.write(json.dumps(result))
self.fd.close()
def run(self): def run(self):
try: try:
@ -129,53 +89,36 @@ class Server():
signal.signal(signal.SIGTERM, self.handler) signal.signal(signal.SIGTERM, self.handler)
signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT) signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT)
(s, addr) = self.socket.accept() self.exception = None
display.display('incoming request accepted on persistent socket', log_only=True) (s, addr) = self.sock.accept()
signal.alarm(0) signal.alarm(0)
signal.signal(signal.SIGALRM, self.command_timeout)
while True: while True:
data = recv_data(s) data = recv_data(s)
if not data: if not data:
break break
signal.signal(signal.SIGALRM, self.command_timeout) signal.alarm(self.connection._play_context.timeout)
signal.alarm(self.play_context.timeout) resp = self.srv.handle_request(data)
op = to_text(data.split(b':')[0])
display.display('socket operation is %s' % op, log_only=True)
method = getattr(self, 'do_%s' % op, None)
rc = 255
stdout = stderr = ''
if not method:
stderr = 'Invalid action specified'
else:
rc, stdout, stderr = method(data)
signal.alarm(0) signal.alarm(0)
display.display('socket operation completed with rc %s' % rc, log_only=True) send_data(s, to_bytes(resp))
send_data(s, to_bytes(rc))
send_data(s, to_bytes(stdout))
send_data(s, to_bytes(stderr))
s.close() s.close()
except Exception as e: except Exception as e:
# socket.accept() will raise EINTR if the socket.close() is called # socket.accept() will raise EINTR if the socket.close() is called
if e.errno != errno.EINTR: if hasattr(e, 'errno'):
display.display(traceback.format_exc(), log_only=True) if e.errno != errno.EINTR:
self.exception = traceback.format_exc()
else:
self.exception = traceback.format_exc()
finally: finally:
# when done, close the connection properly and cleanup # when done, close the connection properly and cleanup
# the socket file so it can be recreated # the socket file so it can be recreated
self.shutdown() self.shutdown()
end_time = datetime.datetime.now()
delta = end_time - self._start_time
display.display('shutdown local socket, connection was active for %s secs' % delta, log_only=True)
def connect_timeout(self, signum, frame): def connect_timeout(self, signum, frame):
display.display('persistent connection idle timeout triggered, timeout value is %s secs' % C.PERSISTENT_CONNECT_TIMEOUT, log_only=True) display.display('persistent connection idle timeout triggered, timeout value is %s secs' % C.PERSISTENT_CONNECT_TIMEOUT, log_only=True)
@ -190,25 +133,25 @@ class Server():
self.shutdown() self.shutdown()
def shutdown(self): def shutdown(self):
display.display('shutdown persistent connection requested', log_only=True) """ Shuts down the local domain socket
"""
if not os.path.exists(self.socket_path): if not os.path.exists(self.socket_path):
display.display('persistent connection is not active', log_only=True)
return return
try: try:
if self.socket: if self.sock:
display.display('closing local listener', log_only=True) self.sock.close()
self.socket.close()
if self.connection: if self.connection:
display.display('closing the connection', log_only=True)
self.connection.close() self.connection.close()
except: except:
pass pass
finally: finally:
if os.path.exists(self.socket_path): if os.path.exists(self.socket_path):
display.display('removing the local control socket', log_only=True)
os.remove(self.socket_path) os.remove(self.socket_path)
setattr(self.connection, '_socket_path', None)
setattr(self.connection, '_connected', False)
display.display('shutdown complete', log_only=True) display.display('shutdown complete', log_only=True)
@ -262,6 +205,13 @@ def communicate(sock, data):
def main(): def main():
""" Called to initiate the connect to the remote device
"""
rc = 0
result = {}
messages = list()
socket_path = None
# Need stdin as a byte stream # Need stdin as a byte stream
if PY3: if PY3:
stdin = sys.stdin.buffer stdin = sys.stdin.buffer
@ -270,116 +220,91 @@ def main():
try: try:
# read the play context data via stdin, which means depickling it # read the play context data via stdin, which means depickling it
# FIXME: as noted above, we will probably need to deserialize the
# connection loader here as well at some point, otherwise this
# won't find role- or playbook-based connection plugins
cur_line = stdin.readline() cur_line = stdin.readline()
init_data = b'' init_data = b''
while cur_line.strip() != b'#END_INIT#': while cur_line.strip() != b'#END_INIT#':
if cur_line == b'': if cur_line == b'':
raise Exception("EOF found before init data was complete") raise Exception("EOF found before init data was complete")
init_data += cur_line init_data += cur_line
cur_line = stdin.readline() cur_line = stdin.readline()
if PY3: if PY3:
pc_data = cPickle.loads(init_data, encoding='bytes') pc_data = cPickle.loads(init_data, encoding='bytes')
else: else:
pc_data = cPickle.loads(init_data) pc_data = cPickle.loads(init_data)
pc = PlayContext() play_context = PlayContext()
pc.deserialize(pc_data) play_context.deserialize(pc_data)
except Exception as e: except Exception as e:
# FIXME: better error message/handling/logging rc = 1
sys.stderr.write(traceback.format_exc()) result.update({
sys.exit("FAIL: %s" % e) 'error': to_text(e),
'exception': traceback.format_exc()
})
ssh = connection_loader.get('ssh', class_only=True) if rc == 0:
cp = ssh._create_control_path(pc.remote_addr, pc.port, pc.remote_user, pc.connection) ssh = connection_loader.get('ssh', class_only=True)
cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user, play_context.connection)
# create the persistent connection dir if need be and create the paths # create the persistent connection dir if need be and create the paths
# which we will be using later # which we will be using later
tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR) tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR)
makedirs_safe(tmp_path) makedirs_safe(tmp_path)
lock_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path)
socket_path = unfrackpath(cp % dict(directory=tmp_path))
# if the socket file doesn't exist, spin up the daemon process lock_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path)
lock_fd = os.open(lock_path, os.O_RDWR | os.O_CREAT, 0o600) socket_path = unfrackpath(cp % dict(directory=tmp_path))
fcntl.lockf(lock_fd, fcntl.LOCK_EX)
if not os.path.exists(socket_path): # if the socket file doesn't exist, spin up the daemon process
pid = do_fork() lock_fd = os.open(lock_path, os.O_RDWR | os.O_CREAT, 0o600)
if pid == 0: fcntl.lockf(lock_fd, fcntl.LOCK_EX)
rc = 0
try:
server = Server(socket_path, pc)
except AnsibleConnectionFailure as exc:
display.display('connecting to host %s returned an error' % pc.remote_addr, log_only=True)
display.display(str(exc), log_only=True)
rc = 1
except Exception as exc:
display.display('failed to create control socket for host %s' % pc.remote_addr, log_only=True)
display.display(traceback.format_exc(), log_only=True)
rc = 1
fcntl.lockf(lock_fd, fcntl.LOCK_UN)
os.close(lock_fd)
if rc == 0:
server.run()
sys.exit(rc)
else:
display.display('re-using existing socket for %s@%s:%s' % (pc.remote_user, pc.remote_addr, pc.port), log_only=True)
fcntl.lockf(lock_fd, fcntl.LOCK_UN) if not os.path.exists(socket_path):
os.close(lock_fd) messages.append('local domain socket does not exist, starting it')
original_path = os.getcwd()
r, w = os.pipe()
pid = fork_process()
timeout = pc.timeout if pid == 0:
while bool(timeout): try:
if os.path.exists(socket_path): os.close(r)
display.vvvv('connected to local socket in %s' % (pc.timeout - timeout), pc.remote_addr) wfd = os.fdopen(w, 'w')
break process = ConnectionProcess(wfd, play_context, socket_path, original_path)
time.sleep(1) process.start()
timeout -= 1 except Exception as exc:
else: messages.append(traceback.format_exc())
raise AnsibleConnectionFailure('timeout waiting for local socket', pc.remote_addr) rc = 1
# now connect to the daemon process fcntl.lockf(lock_fd, fcntl.LOCK_UN)
# FIXME: if the socket file existed but the daemonized process was killed, os.close(lock_fd)
# the connection will timeout here. Need to make this more resilient.
while True:
data = stdin.readline()
if data == b'':
break
if data.strip() == b'':
continue
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) if rc == 0:
process.run()
sys.exit(rc)
else:
os.close(w)
rfd = os.fdopen(r, 'r')
data = json.loads(rfd.read())
messages.extend(data.pop('messages'))
result.update(data)
connect_retry_timeout = C.PERSISTENT_CONNECT_RETRY_TIMEOUT
while bool(connect_retry_timeout):
try:
sock.connect(socket_path)
break
except socket.error:
time.sleep(1)
connect_retry_timeout -= 1
else: else:
display.display('connect retry timeout expired, unable to connect to control socket', pc.remote_addr, pc.remote_user, log_only=True) messages.append('found existing local domain socket, using it!')
display.display('persistent_connect_retry_timeout is %s secs' % (C.PERSISTENT_CONNECT_RETRY_TIMEOUT), pc.remote_addr, pc.remote_user, log_only=True)
sys.stderr.write('failed to connect to control socket')
sys.exit(255)
# send the play_context back into the connection so the connection result.update({
# can handle any privilege escalation activities 'messages': messages,
pc_data = b'CONTEXT: %s' % init_data 'socket_path': socket_path
communicate(sock, pc_data) })
rc, stdout, stderr = communicate(sock, data.strip()) if 'exception' in result:
rc = 1
sys.stdout.write(to_native(stdout)) sys.stderr.write(json.dumps(result))
sys.stderr.write(to_native(stderr)) else:
rc = 0
sock.close() sys.stdout.write(json.dumps(result))
break
sys.exit(rc) sys.exit(rc)

View file

@ -1,31 +1,21 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> # (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
# # (c) 2017 Ansible Project
# This file is part of Ansible # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import os
import pty
import time import time
import json
import subprocess
import traceback import traceback
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleConnectionFailure, AnsibleActionFail, AnsibleActionSkip from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleConnectionFailure, AnsibleActionFail, AnsibleActionSkip
from ansible.executor.task_result import TaskResult from ansible.executor.task_result import TaskResult
from ansible.module_utils.six import iteritems, string_types, binary_type from ansible.module_utils.six import iteritems, string_types, binary_type
from ansible.module_utils.six.moves import cPickle
from ansible.module_utils._text import to_text from ansible.module_utils._text import to_text
from ansible.playbook.conditional import Conditional from ansible.playbook.conditional import Conditional
from ansible.playbook.task import Task from ansible.playbook.task import Task
@ -490,6 +480,8 @@ class TaskExecutor:
not getattr(self._connection, 'connected', False) or not getattr(self._connection, 'connected', False) or
self._play_context.remote_addr != self._connection._play_context.remote_addr): self._play_context.remote_addr != self._connection._play_context.remote_addr):
self._connection = self._get_connection(variables=variables, templar=templar) self._connection = self._get_connection(variables=variables, templar=templar)
if getattr(self._connection, '_socket_path'):
variables['ansible_socket'] = self._connection._socket_path
# only template the vars if the connection actually implements set_host_overrides # only template the vars if the connection actually implements set_host_overrides
# NB: this is expensive, and should be removed once connection-specific vars are being handled by play_context # NB: this is expensive, and should be removed once connection-specific vars are being handled by play_context
sho_impl = getattr(type(self._connection), 'set_host_overrides', None) sho_impl = getattr(type(self._connection), 'set_host_overrides', None)
@ -736,12 +728,7 @@ class TaskExecutor:
if isinstance(i, string_types) and i.startswith("ansible_") and i.endswith("_interpreter"): if isinstance(i, string_types) and i.startswith("ansible_") and i.endswith("_interpreter"):
variables[i] = delegated_vars[i] variables[i] = delegated_vars[i]
# if using persistent paramiko connections (or the action has set the FORCE_PERSISTENT_CONNECTION attribute to True), conn_type = self._play_context.connection
# then we use the persistent connection plugion. Otherwise load the requested connection plugin
if C.USE_PERSISTENT_CONNECTIONS or getattr(self, 'FORCE_PERSISTENT_CONNECTION', False):
conn_type = 'persistent'
else:
conn_type = self._play_context.connection
connection = self._shared_loader_obj.connection_loader.get(conn_type, self._play_context, self._new_stdin) connection = self._shared_loader_obj.connection_loader.get(conn_type, self._play_context, self._new_stdin)
if not connection: if not connection:
@ -749,6 +736,13 @@ class TaskExecutor:
self._play_context.set_options_from_plugin(connection) self._play_context.set_options_from_plugin(connection)
if any(((connection.supports_persistence and C.USE_PERSISTENT_CONNECTIONS), connection.force_persistence)):
display.vvvv('attempting to start connection', host=self._play_context.remote_addr)
display.vvvv('using connection plugin %s' % connection.transport, host=self._play_context.remote_addr)
socket_path = self._start_connection()
display.vvvv('local domain socket path is %s' % socket_path, host=self._play_context.remote_addr)
setattr(connection, '_socket_path', socket_path)
return connection return connection
def _get_action_handler(self, connection, templar): def _get_action_handler(self, connection, templar):
@ -780,3 +774,42 @@ class TaskExecutor:
raise AnsibleError("the handler '%s' was not found" % handler_name) raise AnsibleError("the handler '%s' was not found" % handler_name)
return handler return handler
def _start_connection(self):
'''
Starts the persistent connection
'''
master, slave = pty.openpty()
p = subprocess.Popen(["ansible-connection"], stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdin = os.fdopen(master, 'wb', 0)
os.close(slave)
# Need to force a protocol that is compatible with both py2 and py3.
# That would be protocol=2 or less.
# Also need to force a protocol that excludes certain control chars as
# stdin in this case is a pty and control chars will cause problems.
# that means only protocol=0 will work.
src = cPickle.dumps(self._play_context.serialize(), protocol=0)
stdin.write(src)
stdin.write(b'\n#END_INIT#\n')
(stdout, stderr) = p.communicate()
stdin.close()
if p.returncode == 0:
result = json.loads(stdout)
else:
result = json.loads(stderr)
if 'messages' in result:
for msg in result.get('messages'):
display.vvvv('%s' % msg, host=self._play_context.remote_addr)
if 'error' in result:
if self._play_context.verbosity > 2:
msg = "The full traceback is:\n" + result['exception']
display.display(result['exception'], color=C.COLOR_ERROR)
raise AnsibleError(result['error'])
return result['socket_path']

View file

@ -27,6 +27,7 @@
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os import os
import json
import socket import socket
import struct import struct
import traceback import traceback
@ -35,6 +36,7 @@ import uuid
from functools import partial from functools import partial
from ansible.module_utils._text import to_bytes, to_native, to_text from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.module_utils.six import iteritems
def send_data(s, data): def send_data(s, data):
@ -61,23 +63,14 @@ def recv_data(s):
def exec_command(module, command): def exec_command(module, command):
connection = Connection(module._socket_path)
try: try:
sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) out = connection.exec_command(command)
sf.connect(module._socket_path) except ConnectionError as exc:
code = getattr(exc, 'code', 1)
data = "EXEC: %s" % command message = getattr(exc, 'err', exc)
send_data(sf, to_bytes(data.strip())) return code, '', to_text(message, errors='surrogate_then_replace')
return 0, out, ''
rc = int(recv_data(sf), 10)
stdout = recv_data(sf)
stderr = recv_data(sf)
except socket.error as e:
sf.close()
module.fail_json(msg='unable to connect to socket', err=to_native(e), exception=traceback.format_exc())
sf.close()
return rc, to_native(stdout, errors='surrogate_or_strict'), to_native(stderr, errors='surrogate_or_strict')
def request_builder(method, *args, **kwargs): def request_builder(method, *args, **kwargs):
@ -91,10 +84,19 @@ def request_builder(method, *args, **kwargs):
return req return req
class ConnectionError(Exception):
def __init__(self, message, *args, **kwargs):
super(ConnectionError, self).__init__(message)
for k, v in iteritems(kwargs):
setattr(self, k, v)
class Connection: class Connection:
def __init__(self, module): def __init__(self, socket_path):
self._module = module assert socket_path is not None, 'socket_path must be a value'
self.socket_path = socket_path
def __getattr__(self, name): def __getattr__(self, name):
try: try:
@ -116,30 +118,40 @@ class Connection:
req = request_builder(name, *args, **kwargs) req = request_builder(name, *args, **kwargs)
reqid = req['id'] reqid = req['id']
if not self._module._socket_path: if not os.path.exists(self.socket_path):
self._module.fail_json(msg='provider support not available for this host') raise ConnectionError('socket_path does not exist or cannot be found')
if not os.path.exists(self._module._socket_path):
self._module.fail_json(msg='provider socket does not exist, is the provider running?')
try: try:
data = self._module.jsonify(req) data = json.dumps(req)
rc, out, err = exec_command(self._module, data) out = self.send(data)
response = json.loads(out)
except socket.error as e: except socket.error as e:
self._module.fail_json(msg='unable to connect to socket', err=to_native(e), raise ConnectionError('unable to connect to socket', err=to_text(e, errors='surrogate_then_replace'), exception=traceback.format_exc())
exception=traceback.format_exc())
try:
response = self._module.from_json(to_text(out, errors='surrogate_then_replace'))
except ValueError as exc:
self._module.fail_json(msg=to_text(exc, errors='surrogate_then_replace'))
if response['id'] != reqid: if response['id'] != reqid:
self._module.fail_json(msg='invalid id received') raise ConnectionError('invalid json-rpc id received')
if 'error' in response: if 'error' in response:
msg = response['error'].get('data') or response['error']['message'] err = response.get('error')
self._module.fail_json(msg=to_text(msg, errors='surrogate_then_replace')) msg = err.get('data') or err['message']
code = err['code']
raise ConnectionError(to_text(msg, errors='surrogate_then_replace'), code=code)
return response['result'] return response['result']
def send(self, data):
try:
sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sf.connect(self.socket_path)
send_data(sf, to_bytes(data))
response = recv_data(sf)
except socket.error as e:
sf.close()
raise ConnectionError('unable to connect to socket', err=to_text(e, errors='surrogate_then_replace'), exception=traceback.format_exc())
sf.close()
return to_text(response, errors='surrogate_or_strict')

View file

@ -27,6 +27,7 @@
# #
from contextlib import contextmanager from contextlib import contextmanager
from ansible.module_utils._text import to_bytes, to_text
from ansible.module_utils.connection import exec_command from ansible.module_utils.connection import exec_command
try: try:
@ -38,7 +39,7 @@ NS_MAP = {'nc': "urn:ietf:params:xml:ns:netconf:base:1.0"}
def send_request(module, obj, check_rc=True, ignore_warning=True): def send_request(module, obj, check_rc=True, ignore_warning=True):
request = tostring(obj) request = to_text(tostring(obj), errors='surrogate_or_strict')
rc, out, err = exec_command(module, request) rc, out, err = exec_command(module, request)
if rc != 0 and check_rc: if rc != 0 and check_rc:
error_root = fromstring(err) error_root = fromstring(err)
@ -59,7 +60,7 @@ def send_request(module, obj, check_rc=True, ignore_warning=True):
else: else:
module.fail_json(msg=str(err)) module.fail_json(msg=str(err))
return warnings return warnings
return fromstring(out) return fromstring(to_bytes(out, errors='surrogate_or_strict'))
def children(root, iterable): def children(root, iterable):

View file

@ -91,33 +91,14 @@ def fail_if_missing(module, found, service, msg=''):
module.fail_json(msg='Could not find the requested service %s: %s' % (service, msg)) module.fail_json(msg='Could not find the requested service %s: %s' % (service, msg))
def daemonize(module, cmd): def fork_process():
''' '''
Execute a command while detaching as a daemon, returns rc, stdout, and stderr. This function performs the double fork process to detach from the
parent process and execute.
:arg module: is an AnsibleModule object, used for it's utility methods
:arg cmd: is a list or string representing the command and options to run
This is complex because daemonization is hard for people.
What we do is daemonize a part of this module, the daemon runs the command,
picks up the return code and output, and returns it to the main process.
''' '''
pid = os.fork()
# init some vars
chunk = 4096 # FIXME: pass in as arg?
errors = 'surrogate_or_strict'
# start it!
try:
pipe = os.pipe()
pid = os.fork()
except OSError:
module.fail_json(msg="Error while attempting to fork: %s", exception=traceback.format_exc())
# we don't do any locking as this should be a unique module/process
if pid == 0: if pid == 0:
os.close(pipe[0])
# Set stdin/stdout/stderr to /dev/null # Set stdin/stdout/stderr to /dev/null
fd = os.open(os.devnull, os.O_RDWR) fd = os.open(os.devnull, os.O_RDWR)
@ -140,7 +121,7 @@ def daemonize(module, cmd):
# get new process session and detach # get new process session and detach
sid = os.setsid() sid = os.setsid()
if sid == -1: if sid == -1:
module.fail_json(msg="Unable to detach session while daemonizing") raise Exception("Unable to detach session while daemonizing")
# avoid possible problems with cwd being removed # avoid possible problems with cwd being removed
os.chdir("/") os.chdir("/")
@ -149,6 +130,38 @@ def daemonize(module, cmd):
if pid > 0: if pid > 0:
os._exit(0) os._exit(0)
return pid
def daemonize(module, cmd):
'''
Execute a command while detaching as a daemon, returns rc, stdout, and stderr.
:arg module: is an AnsibleModule object, used for it's utility methods
:arg cmd: is a list or string representing the command and options to run
This is complex because daemonization is hard for people.
What we do is daemonize a part of this module, the daemon runs the command,
picks up the return code and output, and returns it to the main process.
'''
# init some vars
chunk = 4096 # FIXME: pass in as arg?
errors = 'surrogate_or_strict'
# start it!
try:
pipe = os.pipe()
pid = fork_process()
except OSError:
module.fail_json(msg="Error while attempting to fork: %s", exception=traceback.format_exc())
except Exception as exc:
module.fail_json(msg=to_text(exc), exception=traceback.format_exc())
# we don't do any locking as this should be a unique module/process
if pid == 0:
os.close(pipe[0])
# if command is string deal with py2 vs py3 conversions for shlex # if command is string deal with py2 vs py3 conversions for shlex
if not isinstance(cmd, list): if not isinstance(cmd, list):
if PY2: if PY2:

View file

@ -427,6 +427,8 @@ class PlayContext(Base):
# if the final connection type is local, reset the remote_user value to that of the currently logged in user # if the final connection type is local, reset the remote_user value to that of the currently logged in user
# this ensures any become settings are obeyed correctly # this ensures any become settings are obeyed correctly
# we store original in 'connection_user' for use of network/other modules that fallback to it as login user # we store original in 'connection_user' for use of network/other modules that fallback to it as login user
# connection_user to be deprecated once connection=local is removed for
# network modules
if new_info.connection == 'local': if new_info.connection == 'local':
if not new_info.connection_user: if not new_info.connection_user:
new_info.connection_user = new_info.remote_user new_info.connection_user = new_info.remote_user

View file

@ -36,6 +36,7 @@ from ansible.module_utils.json_utils import _filter_non_json_lines
from ansible.module_utils.six import binary_type, string_types, text_type, iteritems, with_metaclass from ansible.module_utils.six import binary_type, string_types, text_type, iteritems, with_metaclass
from ansible.module_utils.six.moves import shlex_quote from ansible.module_utils.six.moves import shlex_quote
from ansible.module_utils._text import to_bytes, to_native, to_text from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.module_utils.connection import Connection
from ansible.parsing.utils.jsonify import jsonify from ansible.parsing.utils.jsonify import jsonify
from ansible.release import __version__ from ansible.release import __version__
from ansible.utils.unsafe_proxy import wrap_var from ansible.utils.unsafe_proxy import wrap_var
@ -604,7 +605,9 @@ class ActionBase(with_metaclass(ABCMeta, object)):
module_args['_ansible_selinux_special_fs'] = C.DEFAULT_SELINUX_SPECIAL_FS module_args['_ansible_selinux_special_fs'] = C.DEFAULT_SELINUX_SPECIAL_FS
# give the module the socket for persistent connections # give the module the socket for persistent connections
module_args['_ansible_socket'] = task_vars.get('ansible_socket') module_args['_ansible_socket'] = getattr(self._connection, 'socket_path')
if not module_args['_ansible_socket']:
module_args['_ansible_socket'] = task_vars.get('ansible_socket')
# make sure all commands use the designated shell executable # make sure all commands use the designated shell executable
module_args['_ansible_shell_executable'] = self._play_context.executable module_args['_ansible_shell_executable'] = self._play_context.executable
@ -818,7 +821,8 @@ class ActionBase(with_metaclass(ABCMeta, object)):
same_user = self._play_context.become_user == self._play_context.remote_user same_user = self._play_context.become_user == self._play_context.remote_user
if sudoable and self._play_context.become and (allow_same_user or not same_user): if sudoable and self._play_context.become and (allow_same_user or not same_user):
display.debug("_low_level_execute_command(): using become for this command") display.debug("_low_level_execute_command(): using become for this command")
cmd = self._play_context.make_become_cmd(cmd, executable=executable) if self._connection.transport != 'network_cli' and self._play_context.become_method != 'enable':
cmd = self._play_context.make_become_cmd(cmd, executable=executable)
if self._connection.allow_executable: if self._connection.allow_executable:
if executable is None: if executable is None:

View file

@ -40,47 +40,35 @@ class ActionModule(_ActionModule):
provider = load_provider(eos_provider_spec, self._task.args) provider = load_provider(eos_provider_spec, self._task.args)
transport = provider['transport'] or 'cli' transport = provider['transport'] or 'cli'
if self._play_context.connection != 'local' and transport == 'cli':
return dict(
failed=True,
msg='invalid connection specified, expected connection=local, '
'got %s' % self._play_context.connection
)
display.vvvv('connection transport is %s' % transport, self._play_context.remote_addr) display.vvvv('connection transport is %s' % transport, self._play_context.remote_addr)
if transport == 'cli': if transport == 'cli':
pc = copy.deepcopy(self._play_context) if self._play_context.connection == 'local':
pc.connection = 'network_cli' pc = copy.deepcopy(self._play_context)
pc.network_os = 'eos' pc.connection = 'network_cli'
pc.remote_addr = provider['host'] or self._play_context.remote_addr pc.network_os = 'eos'
pc.port = int(provider['port'] or self._play_context.port or 22) pc.remote_addr = provider['host'] or self._play_context.remote_addr
pc.remote_user = provider['username'] or self._play_context.connection_user pc.port = int(provider['port'] or self._play_context.port or 22)
pc.password = provider['password'] or self._play_context.password pc.remote_user = provider['username'] or self._play_context.connection_user
pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file pc.password = provider['password'] or self._play_context.password
pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT) pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file
pc.become = provider['authorize'] or False pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
pc.become_pass = provider['auth_pass'] pc.become = provider['authorize'] or False
if pc.become:
pc.become_method = 'enable'
pc.become_pass = provider['auth_pass']
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
socket_path = connection.run() socket_path = connection.run()
display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) display.vvvv('socket_path: %s' % socket_path, pc.remote_addr)
if not socket_path: if not socket_path:
return {'failed': True, return {'failed': True,
'msg': 'unable to open shell. Please see: ' + 'msg': 'unable to open shell. Please see: ' +
'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
# make sure we are in the right cli context which should be task_vars['ansible_socket'] = socket_path
# enable mode and not config module
rc, out, err = connection.exec_command('prompt()')
while '(config' in str(out):
display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr)
connection.exec_command('exit')
rc, out, err = connection.exec_command('prompt()')
task_vars['ansible_socket'] = socket_path
else: else:
provider['transport'] = 'eapi' provider['transport'] = 'eapi'

View file

@ -38,50 +38,38 @@ class ActionModule(_ActionModule):
def run(self, tmp=None, task_vars=None): def run(self, tmp=None, task_vars=None):
if self._play_context.connection != 'local': if self._play_context.connection == 'local':
return dict( provider = load_provider(ios_provider_spec, self._task.args)
failed=True,
msg='invalid connection specified, expected connection=local, '
'got %s' % self._play_context.connection
)
provider = load_provider(ios_provider_spec, self._task.args) pc = copy.deepcopy(self._play_context)
pc.connection = 'network_cli'
pc.network_os = 'ios'
pc.remote_addr = provider['host'] or self._play_context.remote_addr
pc.port = int(provider['port'] or self._play_context.port or 22)
pc.remote_user = provider['username'] or self._play_context.connection_user
pc.password = provider['password'] or self._play_context.password
pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file
pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
pc.become = provider['authorize'] or False
if pc.become:
pc.become_method = 'enable'
pc.become_pass = provider['auth_pass']
pc = copy.deepcopy(self._play_context) display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
pc.connection = 'network_cli' connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
pc.network_os = 'ios'
pc.remote_addr = provider['host'] or self._play_context.remote_addr
pc.port = int(provider['port'] or self._play_context.port or 22)
pc.remote_user = provider['username'] or self._play_context.connection_user
pc.password = provider['password'] or self._play_context.password
pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file
pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
pc.become = provider['authorize'] or False
pc.become_pass = provider['auth_pass']
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) socket_path = connection.run()
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) display.vvvv('socket_path: %s' % socket_path, pc.remote_addr)
if not socket_path:
return {'failed': True,
'msg': 'unable to open shell. Please see: ' +
'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
socket_path = connection.run() task_vars['ansible_socket'] = socket_path
display.vvvv('socket_path: %s' % socket_path, pc.remote_addr)
if not socket_path:
return {'failed': True,
'msg': 'unable to open shell. Please see: ' +
'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
# make sure we are in the right cli context which should be if self._play_context.become_method == 'enable':
# enable mode and not config module self._play_context.become = False
rc, out, err = connection.exec_command('prompt()') self._play_context.become_method = None
while str(out).strip().endswith(')#'):
display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr)
connection.exec_command('exit')
rc, out, err = connection.exec_command('prompt()')
task_vars['ansible_socket'] = socket_path
if self._play_context.become_method == 'enable':
self._play_context.become = False
self._play_context.become_method = None
result = super(ActionModule, self).run(tmp, task_vars) result = super(ActionModule, self).run(tmp, task_vars)
return result return result

View file

@ -38,43 +38,29 @@ class ActionModule(_ActionModule):
def run(self, tmp=None, task_vars=None): def run(self, tmp=None, task_vars=None):
if self._play_context.connection != 'local': if self._play_context.connection == 'local':
return dict( provider = load_provider(iosxr_provider_spec, self._task.args)
failed=True,
msg='invalid connection specified, expected connection=local, '
'got %s' % self._play_context.connection
)
provider = load_provider(iosxr_provider_spec, self._task.args) pc = copy.deepcopy(self._play_context)
pc.connection = 'network_cli'
pc.network_os = 'iosxr'
pc.remote_addr = provider['host'] or self._play_context.remote_addr
pc.port = int(provider['port'] or self._play_context.port or 22)
pc.remote_user = provider['username'] or self._play_context.connection_user
pc.password = provider['password'] or self._play_context.password
pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
pc = copy.deepcopy(self._play_context) display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
pc.connection = 'network_cli' connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
pc.network_os = 'iosxr'
pc.remote_addr = provider['host'] or self._play_context.remote_addr
pc.port = int(provider['port'] or self._play_context.port or 22)
pc.remote_user = provider['username'] or self._play_context.connection_user
pc.password = provider['password'] or self._play_context.password
pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) socket_path = connection.run()
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) display.vvvv('socket_path: %s' % socket_path, pc.remote_addr)
if not socket_path:
return {'failed': True,
'msg': 'unable to open shell. Please see: ' +
'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
socket_path = connection.run() task_vars['ansible_socket'] = socket_path
display.vvvv('socket_path: %s' % socket_path, pc.remote_addr)
if not socket_path:
return {'failed': True,
'msg': 'unable to open shell. Please see: ' +
'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
# make sure we are in the right cli context which should be
# enable mode and not config module
rc, out, err = connection.exec_command('prompt()')
while str(out).strip().endswith(')#'):
display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr)
connection.exec_command('exit')
rc, out, err = connection.exec_command('prompt()')
task_vars['ansible_socket'] = socket_path
result = super(ActionModule, self).run(tmp, task_vars) result = super(ActionModule, self).run(tmp, task_vars)
return result return result

View file

@ -38,14 +38,6 @@ except ImportError:
class ActionModule(_ActionModule): class ActionModule(_ActionModule):
def run(self, tmp=None, task_vars=None): def run(self, tmp=None, task_vars=None):
if self._play_context.connection != 'local':
return dict(
failed=True,
msg='invalid connection specified, expected connection=local, '
'got %s' % self._play_context.connection
)
module = module_loader._load_module_source(self._task.action, module_loader.find_plugin(self._task.action)) module = module_loader._load_module_source(self._task.action, module_loader.find_plugin(self._task.action))
if not getattr(module, 'USE_PERSISTENT_CONNECTION', False): if not getattr(module, 'USE_PERSISTENT_CONNECTION', False):
@ -72,25 +64,27 @@ class ActionModule(_ActionModule):
pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT) pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
socket_path = connection.run() if self._play_context.connection == 'local':
display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
if not socket_path:
return {'failed': True,
'msg': 'unable to open shell. Please see: ' +
'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
if pc.connection == 'network_cli': socket_path = connection.run()
# make sure we are in the right cli context which should be display.vvvv('socket_path: %s' % socket_path, pc.remote_addr)
# enable mode and not config module if not socket_path:
rc, out, err = connection.exec_command('prompt()') return {'failed': True,
while str(out).strip().endswith(')#'): 'msg': 'unable to open shell. Please see: ' +
display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
connection.exec_command('exit')
if pc.connection == 'network_cli':
# make sure we are in the right cli context which should be
# enable mode and not config module
rc, out, err = connection.exec_command('prompt()') rc, out, err = connection.exec_command('prompt()')
while str(out).strip().endswith(')#'):
display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr)
connection.exec_command('exit')
rc, out, err = connection.exec_command('prompt()')
task_vars['ansible_socket'] = socket_path task_vars['ansible_socket'] = socket_path
result = super(ActionModule, self).run(tmp, task_vars) result = super(ActionModule, self).run(tmp, task_vars)
return result return result

View file

@ -37,13 +37,6 @@ except ImportError:
class ActionModule(ActionBase): class ActionModule(ActionBase):
def run(self, tmp=None, task_vars=None): def run(self, tmp=None, task_vars=None):
if self._play_context.connection != 'local':
return dict(
failed=True,
msg='invalid connection specified, expected connection=local, '
'got %s' % self._play_context.connection
)
play_context = copy.deepcopy(self._play_context) play_context = copy.deepcopy(self._play_context)
play_context.network_os = self._get_network_os(task_vars) play_context.network_os = self._get_network_os(task_vars)
@ -74,8 +67,9 @@ class ActionModule(ActionBase):
play_context.become = self.provider['authorize'] or False play_context.become = self.provider['authorize'] or False
play_context.become_pass = self.provider['auth_pass'] play_context.become_pass = self.provider['auth_pass']
socket_path = self._start_connection(play_context) if self._play_context.connection == 'local':
task_vars['ansible_socket'] = socket_path socket_path = self._start_connection(play_context)
task_vars['ansible_socket'] = socket_path
if 'fail_on_missing_module' not in self._task.args: if 'fail_on_missing_module' not in self._task.args:
self._task.args['fail_on_missing_module'] = False self._task.args['fail_on_missing_module'] = False

View file

@ -40,44 +40,31 @@ class ActionModule(_ActionModule):
provider = load_provider(nxos_provider_spec, self._task.args) provider = load_provider(nxos_provider_spec, self._task.args)
transport = provider['transport'] or 'cli' transport = provider['transport'] or 'cli'
if self._play_context.connection != 'local' and transport == 'cli':
return dict(
failed=True,
msg='invalid connection specified, expected connection=local, '
'got %s' % self._play_context.connection
)
display.vvvv('connection transport is %s' % transport, self._play_context.remote_addr) display.vvvv('connection transport is %s' % transport, self._play_context.remote_addr)
if transport == 'cli': if transport == 'cli':
pc = copy.deepcopy(self._play_context) if self._play_context.connection == 'local':
pc.connection = 'network_cli' pc = copy.deepcopy(self._play_context)
pc.network_os = 'nxos' pc.connection = 'network_cli'
pc.remote_addr = provider['host'] or self._play_context.remote_addr pc.network_os = 'nxos'
pc.port = int(provider['port'] or self._play_context.port or 22) pc.remote_addr = provider['host'] or self._play_context.remote_addr
pc.remote_user = provider['username'] or self._play_context.connection_user pc.port = int(provider['port'] or self._play_context.port or 22)
pc.password = provider['password'] or self._play_context.password pc.remote_user = provider['username'] or self._play_context.connection_user
pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file pc.password = provider['password'] or self._play_context.password
pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT) pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
socket_path = connection.run() connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
display.vvvv('socket_path: %s' % socket_path, pc.remote_addr)
if not socket_path:
return {'failed': True,
'msg': 'unable to open shell. Please see: ' +
'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
# make sure we are in the right cli context which should be socket_path = connection.run()
# enable mode and not config module display.vvvv('socket_path: %s' % socket_path, pc.remote_addr)
rc, out, err = connection.exec_command('prompt()') if not socket_path:
while str(out).strip().endswith(')#'): return {'failed': True,
display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) 'msg': 'unable to open shell. Please see: ' +
connection.exec_command('exit') 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
rc, out, err = connection.exec_command('prompt()')
task_vars['ansible_socket'] = socket_path task_vars['ansible_socket'] = socket_path
else: else:
provider['transport'] = 'nxapi' provider['transport'] = 'nxapi'

View file

@ -37,13 +37,6 @@ except ImportError:
class ActionModule(_ActionModule): class ActionModule(_ActionModule):
def run(self, tmp=None, task_vars=None): def run(self, tmp=None, task_vars=None):
if self._play_context.connection != 'local':
return dict(
failed=True,
msg='invalid connection specified, expected connection=local, '
'got %s' % self._play_context.connection
)
provider = load_provider(vyos_provider_spec, self._task.args) provider = load_provider(vyos_provider_spec, self._task.args)
pc = copy.deepcopy(self._play_context) pc = copy.deepcopy(self._play_context)
@ -57,24 +50,18 @@ class ActionModule(_ActionModule):
pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT) pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
socket_path = connection.run() if self._play_context.connection == 'local':
display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
if not socket_path:
return {'failed': True,
'msg': 'unable to open shell. Please see: ' +
'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
# make sure we are in the right cli context which should be socket_path = connection.run()
# enable mode and not config module display.vvvv('socket_path: %s' % socket_path, pc.remote_addr)
rc, out, err = connection.exec_command('prompt()') if not socket_path:
while str(out).strip().endswith('#'): return {'failed': True,
display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) 'msg': 'unable to open shell. Please see: ' +
connection.exec_command('exit') 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
rc, out, err = connection.exec_command('prompt()')
task_vars['ansible_socket'] = socket_path task_vars['ansible_socket'] = socket_path
result = super(ActionModule, self).run(tmp, task_vars) result = super(ActionModule, self).run(tmp, task_vars)
return result return result

View file

@ -1,21 +1,7 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
# (c) 2015 Toshio Kuratomi <tkuratomi@ansible.com> # (c) 2015 Toshio Kuratomi <tkuratomi@ansible.com>
# # (c) 2017, Peter Sprygada <psprygad@redhat.com>
# This file is part of Ansible # (c) 2017 Ansible Project
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
@ -69,6 +55,11 @@ class ConnectionBase(AnsiblePlugin):
module_implementation_preferences = ('',) module_implementation_preferences = ('',)
allow_executable = True allow_executable = True
# the following control whether or not the connection supports the
# persistent connection framework or not
supports_persistence = False
force_persistence = False
def __init__(self, play_context, new_stdin, *args, **kwargs): def __init__(self, play_context, new_stdin, *args, **kwargs):
super(ConnectionBase, self).__init__() super(ConnectionBase, self).__init__()
@ -88,6 +79,8 @@ class ConnectionBase(AnsiblePlugin):
self.prompt = None self.prompt = None
self._connected = False self._connected = False
self._socket_path = None
# load the shell plugin for this action/connection # load the shell plugin for this action/connection
if play_context.shell: if play_context.shell:
shell_type = play_context.shell shell_type = play_context.shell
@ -110,6 +103,11 @@ class ConnectionBase(AnsiblePlugin):
'''Read-only property holding whether the connection to the remote host is active or closed.''' '''Read-only property holding whether the connection to the remote host is active or closed.'''
return self._connected return self._connected
@property
def socket_path(self):
'''Read-only property holding the connection socket path for this remote host'''
return self._socket_path
def _become_method_supported(self): def _become_method_supported(self):
''' Checks if the current class supports this privilege escalation method ''' ''' Checks if the current class supports this privilege escalation method '''

View file

@ -71,15 +71,14 @@ DOCUMENTATION = """
import os import os
import logging import logging
import json
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleConnectionFailure, AnsibleError from ansible.errors import AnsibleConnectionFailure, AnsibleError
from ansible.module_utils._text import to_bytes, to_native, to_text from ansible.module_utils._text import to_bytes, to_native
from ansible.module_utils.parsing.convert_bool import BOOLEANS_TRUE from ansible.module_utils.parsing.convert_bool import BOOLEANS_TRUE
from ansible.plugins.loader import netconf_loader from ansible.plugins.loader import netconf_loader
from ansible.plugins.connection import ConnectionBase, ensure_connect from ansible.plugins.connection import ConnectionBase, ensure_connect
from ansible.utils.jsonrpc import Rpc from ansible.plugins.connection.local import Connection as LocalConnection
try: try:
from ncclient import manager from ncclient import manager
@ -98,11 +97,12 @@ except ImportError:
logging.getLogger('ncclient').setLevel(logging.INFO) logging.getLogger('ncclient').setLevel(logging.INFO)
class Connection(Rpc, ConnectionBase): class Connection(ConnectionBase):
"""NetConf connections""" """NetConf connections"""
transport = 'netconf' transport = 'netconf'
has_pipelining = False has_pipelining = False
force_persistence = True
def __init__(self, play_context, new_stdin, *args, **kwargs): def __init__(self, play_context, new_stdin, *args, **kwargs):
super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs) super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs)
@ -113,18 +113,50 @@ class Connection(Rpc, ConnectionBase):
self._manager = None self._manager = None
self._connected = False self._connected = False
self._local = LocalConnection(play_context, new_stdin, *args, **kwargs)
def exec_command(self, request, in_data=None, sudoable=True):
"""Sends the request to the node and returns the reply
The method accepts two forms of request. The first form is as a byte
string that represents xml string be send over netconf session.
The second form is a json-rpc (2.0) byte string.
"""
if self._manager:
# to_ele operates on native strings
request = to_ele(to_native(request, errors='surrogate_or_strict'))
if request is None:
return 'unable to parse request'
try:
reply = self._manager.rpc(request)
except RPCError as exc:
return to_xml(exc.xml)
return reply.data_xml
else:
return self._local.exec_command(request, in_data, sudoable)
def put_file(self, in_path, out_path):
"""Transfer a file from local to remote"""
return self._local.put_file(in_path, out_path)
def fetch_file(self, in_path, out_path):
"""Fetch a file from remote to local"""
return self._local.fetch_file(in_path, out_path)
def _connect(self): def _connect(self):
super(Connection, self)._connect() super(Connection, self)._connect()
display.display('ssh connection done, stating ncclient', log_only=True) display.display('ssh connection done, starting ncclient', log_only=True)
self.allow_agent = True allow_agent = True
if self._play_context.password is not None: if self._play_context.password is not None:
self.allow_agent = False allow_agent = False
self.key_filename = None key_filename = None
if self._play_context.private_key_file: if self._play_context.private_key_file:
self.key_filename = os.path.expanduser(self._play_context.private_key_file) key_filename = os.path.expanduser(self._play_context.private_key_file)
network_os = self._play_context.network_os network_os = self._play_context.network_os
@ -149,16 +181,18 @@ class Connection(Rpc, ConnectionBase):
port=self._play_context.port or 830, port=self._play_context.port or 830,
username=self._play_context.remote_user, username=self._play_context.remote_user,
password=self._play_context.password, password=self._play_context.password,
key_filename=str(self.key_filename), key_filename=str(key_filename),
hostkey_verify=C.HOST_KEY_CHECKING, hostkey_verify=C.HOST_KEY_CHECKING,
look_for_keys=C.PARAMIKO_LOOK_FOR_KEYS, look_for_keys=C.PARAMIKO_LOOK_FOR_KEYS,
allow_agent=self.allow_agent, allow_agent=allow_agent,
timeout=self._play_context.timeout, timeout=self._play_context.timeout,
device_params={'name': network_os}, device_params={'name': network_os},
ssh_config=ssh_config ssh_config=ssh_config
) )
except SSHUnknownHostError as exc: except SSHUnknownHostError as exc:
raise AnsibleConnectionFailure(str(exc)) raise AnsibleConnectionFailure(str(exc))
except ImportError as exc:
raise AnsibleError("connection=netconf is not supported on {0}".format(network_os))
if not self._manager.connected: if not self._manager.connected:
return 1, b'', b'not connected' return 1, b'', b'not connected'
@ -169,7 +203,6 @@ class Connection(Rpc, ConnectionBase):
self._netconf = netconf_loader.get(network_os, self) self._netconf = netconf_loader.get(network_os, self)
if self._netconf: if self._netconf:
self._rpc.add(self._netconf)
display.display('loaded netconf plugin for network_os %s' % network_os, log_only=True) display.display('loaded netconf plugin for network_os %s' % network_os, log_only=True)
else: else:
display.display('unable to load netconf for network_os %s' % network_os) display.display('unable to load netconf for network_os %s' % network_os)
@ -181,46 +214,3 @@ class Connection(Rpc, ConnectionBase):
self._manager.close_session() self._manager.close_session()
self._connected = False self._connected = False
super(Connection, self).close() super(Connection, self).close()
@ensure_connect
def exec_command(self, request):
"""Sends the request to the node and returns the reply
The method accepts two forms of request. The first form is as a byte
string that represents xml string be send over netconf session.
The second form is a json-rpc (2.0) byte string.
"""
try:
obj = json.loads(to_text(request, errors='surrogate_or_strict'))
if 'jsonrpc' in obj:
if self._netconf:
out = self._exec_rpc(obj)
else:
out = self.internal_error("netconf plugin is not supported for network_os %s" % self._play_context.network_os)
return 0, to_bytes(out, errors='surrogate_or_strict'), b''
else:
err = self.invalid_request(obj)
return 1, b'', to_bytes(err, errors='surrogate_or_strict')
except (ValueError, TypeError):
# to_ele operates on native strings
request = to_native(request, errors='surrogate_or_strict')
req = to_ele(request)
if req is None:
return 1, b'', b'unable to parse request'
try:
reply = self._manager.rpc(req)
except RPCError as exc:
return 1, b'', to_bytes(to_xml(exc.xml), errors='surrogate_or_strict')
return 0, to_bytes(reply.data_xml, errors='surrogate_or_strict'), b''
def put_file(self, in_path, out_path):
"""Transfer a file from local to remote"""
pass
def fetch_file(self, in_path, out_path):
"""Fetch a file from remote to local"""
pass

View file

@ -47,6 +47,7 @@ DOCUMENTATION = """
import json import json
import logging import logging
import re import re
import os
import signal import signal
import socket import socket
import traceback import traceback
@ -57,9 +58,11 @@ from ansible import constants as C
from ansible.errors import AnsibleConnectionFailure from ansible.errors import AnsibleConnectionFailure
from ansible.module_utils.six import BytesIO, binary_type from ansible.module_utils.six import BytesIO, binary_type
from ansible.module_utils._text import to_bytes, to_text from ansible.module_utils._text import to_bytes, to_text
from ansible.plugins.loader import cliconf_loader, terminal_loader from ansible.plugins.loader import cliconf_loader, terminal_loader, connection_loader
from ansible.plugins.connection.paramiko_ssh import Connection as _Connection from ansible.plugins.connection import ConnectionBase
from ansible.utils.jsonrpc import Rpc from ansible.plugins.connection.local import Connection as LocalConnection
from ansible.plugins.connection.paramiko_ssh import Connection as ParamikoSshConnection
from ansible.utils.path import unfrackpath, makedirs_safe
try: try:
from __main__ import display from __main__ import display
@ -68,31 +71,73 @@ except ImportError:
display = Display() display = Display()
class Connection(Rpc, _Connection): class Connection(ConnectionBase):
''' CLI (shell) SSH connections on Paramiko ''' ''' CLI (shell) SSH connections on Paramiko '''
transport = 'network_cli' transport = 'network_cli'
has_pipelining = True has_pipelining = True
force_persistence = True
def __init__(self, play_context, new_stdin, *args, **kwargs): def __init__(self, play_context, new_stdin, *args, **kwargs):
super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs) super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs)
self._terminal = None self.ssh = None
self._cliconf = None self._ssh_shell = None
self._shell = None
self._matched_prompt = None self._matched_prompt = None
self._matched_pattern = None self._matched_pattern = None
self._last_response = None self._last_response = None
self._history = list() self._history = list()
self._play_context = play_context
if play_context.verbosity > 3: self._local = LocalConnection(play_context, new_stdin, *args, **kwargs)
self._terminal = None
self._cliconf = None
if self._play_context.verbosity > 3:
logging.getLogger('paramiko').setLevel(logging.DEBUG) logging.getLogger('paramiko').setLevel(logging.DEBUG)
# reconstruct the socket_path and set instance values accordingly
self._update_connection_state()
def __getattr__(self, name):
try:
return self.__dict__[name]
except KeyError:
if name.startswith('_'):
raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name))
return getattr(self._cliconf, name)
def exec_command(self, cmd, in_data=None, sudoable=True):
# this try..except block is just to handle the transition to supporting
# network_cli as a toplevel connection. Once connection=local is gone,
# this block can be removed as well and all calls passed directly to
# the local connection
if self._ssh_shell:
try:
cmd = json.loads(to_text(cmd, errors='surrogate_or_strict'))
kwargs = {'command': to_bytes(cmd['command'], errors='surrogate_or_strict')}
for key in ('prompts', 'answer', 'send_only'):
if key in cmd:
kwargs[key] = to_bytes(cmd[key], errors='surrogate_or_strict')
return self.send(**kwargs)
except ValueError:
cmd = to_bytes(cmd, errors='surrogate_or_strict')
return self.send(command=cmd)
else:
return self._local.exec_command(cmd, in_data, sudoable)
def put_file(self, in_path, out_path):
return self._local.put_file(in_path, out_path)
def fetch_file(self, in_path, out_path):
return self._local.fetch_file(in_path, out_path)
def update_play_context(self, play_context): def update_play_context(self, play_context):
"""Updates the play context information for the connection""" """Updates the play context information for the connection"""
display.display('updating play_context for connection', log_only=True) display.vvvv('updating play_context for connection', host=self._play_context.remote_addr)
if self._play_context.become is False and play_context.become is True: if self._play_context.become is False and play_context.become is True:
auth_pass = play_context.become_pass auth_pass = play_context.become_pass
@ -104,17 +149,22 @@ class Connection(Rpc, _Connection):
self._play_context = play_context self._play_context = play_context
def _connect(self): def _connect(self):
"""Connections to the device and sets the terminal type""" '''
Connects to the remote device and starts the terminal
'''
if self.connected:
return
if self._play_context.password and not self._play_context.private_key_file: if self._play_context.password and not self._play_context.private_key_file:
C.PARAMIKO_LOOK_FOR_KEYS = False C.PARAMIKO_LOOK_FOR_KEYS = False
super(Connection, self)._connect() ssh = ParamikoSshConnection(self._play_context, '/dev/null')._connect()
self.ssh = ssh.ssh
display.display('ssh connection done, setting terminal', log_only=True) display.vvvv('ssh connection done, setting terminal', host=self._play_context.remote_addr)
self._shell = self.ssh.invoke_shell() self._ssh_shell = self.ssh.invoke_shell()
self._shell.settimeout(self._play_context.timeout) self._ssh_shell.settimeout(self._play_context.timeout)
network_os = self._play_context.network_os network_os = self._play_context.network_os
if not network_os: if not network_os:
@ -127,53 +177,83 @@ class Connection(Rpc, _Connection):
if not self._terminal: if not self._terminal:
raise AnsibleConnectionFailure('network os %s is not supported' % network_os) raise AnsibleConnectionFailure('network os %s is not supported' % network_os)
display.display('loaded terminal plugin for network_os %s' % network_os, log_only=True) display.vvvv('loaded terminal plugin for network_os %s' % network_os, host=self._play_context.remote_addr)
self._cliconf = cliconf_loader.get(network_os, self) self._cliconf = cliconf_loader.get(network_os, self)
if self._cliconf: if self._cliconf:
self._rpc.add(self._cliconf) display.vvvv('loaded cliconf plugin for network_os %s' % network_os, host=self._play_context.remote_addr)
display.display('loaded cliconf plugin for network_os %s' % network_os, log_only=True)
else: else:
display.display('unable to load cliconf for network_os %s' % network_os) display.vvvv('unable to load cliconf for network_os %s' % network_os)
self.receive() self.receive()
display.display('firing event: on_open_shell()', log_only=True) display.vvvv('firing event: on_open_shell()', host=self._play_context.remote_addr)
self._terminal.on_open_shell() self._terminal.on_open_shell()
if getattr(self._play_context, 'become', None): if self._play_context.become and self._play_context.become_method == 'enable':
display.display('firing event: on_authorize', log_only=True) display.vvvv('firing event: on_authorize', host=self._play_context.remote_addr)
auth_pass = self._play_context.become_pass auth_pass = self._play_context.become_pass
self._terminal.on_authorize(passwd=auth_pass) self._terminal.on_authorize(passwd=auth_pass)
display.vvvv('ssh connection has completed successfully', host=self._play_context.remote_addr)
self._connected = True self._connected = True
display.display('ssh connection has completed successfully', log_only=True)
return self
def _update_connection_state(self):
'''
Reconstruct the connection socket_path and check if it exists
If the socket path exists then the connection is active and set
both the _socket_path value to the path and the _connected value
to True. If the socket path doesn't exist, leave the socket path
value to None and the _connected value to False
'''
ssh = connection_loader.get('ssh', class_only=True)
cp = ssh._create_control_path(self._play_context.remote_addr, self._play_context.port, self._play_context.remote_user)
tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR)
socket_path = unfrackpath(cp % dict(directory=tmp_path))
if os.path.exists(socket_path):
self._connected = True
self._socket_path = socket_path
def reset(self):
'''
Reset the connection
'''
if self._socket_path:
display.vvvv('resetting persistent connection for socket_path %s' % self._socket_path, host=self._play_context.remote_addr)
self.shutdown()
def close(self): def close(self):
"""Close the active connection to the device '''
""" Close the active connection to the device
display.display("closing ssh connection to device", log_only=True) '''
if self._shell: # only close the connection if its connected.
display.display("firing event: on_close_shell()", log_only=True) if self._connected:
self._terminal.on_close_shell() display.debug("closing ssh connection to device")
self._shell.close() if self._ssh_shell:
self._shell = None display.debug("firing event: on_close_shell()")
display.display("cli session is now closed", log_only=True) self._terminal.on_close_shell()
self._ssh_shell.close()
super(Connection, self).close() self._ssh_shell = None
display.debug("cli session is now closed")
self._connected = False self._connected = False
display.display("ssh connection has been closed successfully", log_only=True) display.debug("ssh connection has been closed successfully")
def receive(self, command=None, prompts=None, answer=None): def receive(self, command=None, prompts=None, answer=None):
"""Handles receiving of output from command""" '''
Handles receiving of output from command
'''
recv = BytesIO() recv = BytesIO()
handled = False handled = False
self._matched_prompt = None self._matched_prompt = None
while True: while True:
data = self._shell.recv(256) data = self._ssh_shell.recv(256)
recv.write(data) recv.write(data)
offset = recv.tell() - 256 if recv.tell() > 256 else 0 offset = recv.tell() - 256 if recv.tell() > 256 else 0
@ -190,25 +270,30 @@ class Connection(Rpc, _Connection):
return self._sanitize(resp, command) return self._sanitize(resp, command)
def send(self, command, prompts=None, answer=None, send_only=False): def send(self, command, prompts=None, answer=None, send_only=False):
"""Sends the command to the device in the opened shell""" '''
Sends the command to the device in the opened shell
'''
try: try:
self._history.append(command) self._history.append(command)
self._shell.sendall(b'%s\r' % command) self._ssh_shell.sendall(b'%s\r' % command)
if send_only: if send_only:
return return
return self.receive(command, prompts, answer) response = self.receive(command, prompts, answer)
return to_text(response, errors='surrogate_or_strict')
except (socket.timeout, AttributeError): except (socket.timeout, AttributeError):
display.display(traceback.format_exc(), log_only=True) display.vvvv(traceback.format_exc(), host=self._play_context.remote_addr)
raise AnsibleConnectionFailure("timeout trying to send command: %s" % command.strip()) raise AnsibleConnectionFailure("timeout trying to send command: %s" % command.strip())
def _strip(self, data): def _strip(self, data):
"""Removes ANSI codes from device response""" '''
Removes ANSI codes from device response
'''
for regex in self._terminal.ansi_re: for regex in self._terminal.ansi_re:
data = regex.sub(b'', data) data = regex.sub(b'', data)
return data return data
def _handle_prompt(self, resp, prompts, answer): def _handle_prompt(self, resp, prompts, answer):
""" '''
Matches the command prompt and responds Matches the command prompt and responds
:arg resp: Byte string containing the raw response from the remote :arg resp: Byte string containing the raw response from the remote
@ -216,17 +301,19 @@ class Connection(Rpc, _Connection):
:arg answer: Byte string to send back to the remote if we find a prompt. :arg answer: Byte string to send back to the remote if we find a prompt.
A carriage return is automatically appended to this string. A carriage return is automatically appended to this string.
:returns: True if a prompt was found in ``resp``. False otherwise :returns: True if a prompt was found in ``resp``. False otherwise
""" '''
prompts = [re.compile(r, re.I) for r in prompts] prompts = [re.compile(r, re.I) for r in prompts]
for regex in prompts: for regex in prompts:
match = regex.search(resp) match = regex.search(resp)
if match: if match:
self._shell.sendall(b'%s\r' % answer) self._ssh_shell.sendall(b'%s\r' % answer)
return True return True
return False return False
def _sanitize(self, resp, command=None): def _sanitize(self, resp, command=None):
"""Removes elements from the response before returning to the caller""" '''
Removes elements from the response before returning to the caller
'''
cleaned = [] cleaned = []
for line in resp.splitlines(): for line in resp.splitlines():
if (command and line.strip() == command.strip()) or self._matched_prompt.strip() in line: if (command and line.strip() == command.strip()) or self._matched_prompt.strip() in line:
@ -235,7 +322,8 @@ class Connection(Rpc, _Connection):
return b'\n'.join(cleaned).strip() return b'\n'.join(cleaned).strip()
def _find_prompt(self, response): def _find_prompt(self, response):
"""Searches the buffered response for a matching command prompt""" '''Searches the buffered response for a matching command prompt
'''
errored_response = None errored_response = None
is_error_message = False is_error_message = False
for regex in self._terminal.terminal_stderr_re: for regex in self._terminal.terminal_stderr_re:
@ -264,64 +352,3 @@ class Connection(Rpc, _Connection):
raise AnsibleConnectionFailure(errored_response) raise AnsibleConnectionFailure(errored_response)
return False return False
def alarm_handler(self, signum, frame):
"""Alarm handler raised in case of command timeout """
display.display('closing shell due to sigalarm', log_only=True)
self.close()
def exec_command(self, cmd):
"""Executes the cmd on in the shell and returns the output
The method accepts three forms of cmd. The first form is as a byte
string that represents the command to be executed in the shell. The
second form is as a utf8 JSON byte string with additional keywords.
The third form is a json-rpc (2.0)
Keywords supported for cmd:
:command: the command string to execute
:prompt: the expected prompt generated by executing command.
This can be a string or a list of strings
:answer: the string to respond to the prompt with
:sendonly: bool to disable waiting for response
:arg cmd: the byte string that represents the command to be executed
which can be a single command or a json encoded string.
:returns: a tuple of (return code, stdout, stderr). The return
code is an integer and stdout and stderr are byte strings
"""
try:
obj = json.loads(to_text(cmd, errors='surrogate_or_strict'))
except (ValueError, TypeError):
obj = {'command': to_bytes(cmd.strip(), errors='surrogate_or_strict')}
obj = dict((k, to_bytes(v, errors='surrogate_or_strict', nonstring='passthru')) for k, v in obj.items())
if 'prompt' in obj:
if isinstance(obj['prompt'], binary_type):
# Prompt was a string
obj['prompt'] = [obj['prompt']]
elif not isinstance(obj['prompt'], Sequence):
# Convert nonstrings into byte strings (to_bytes(5) => b'5')
if obj['prompt'] is not None:
obj['prompt'] = [to_bytes(obj['prompt'], errors='surrogate_or_strict')]
else:
# Prompt was a Sequence of strings. Make sure they're byte strings
obj['prompt'] = [to_bytes(p, errors='surrogate_or_strict') for p in obj['prompt'] if p is not None]
if 'jsonrpc' in obj:
if self._cliconf:
out = self._exec_rpc(obj)
else:
out = self.internal_error("cliconf is not supported for network_os %s" % self._play_context.network_os)
return 0, to_bytes(out, errors='surrogate_or_strict'), b''
if obj['command'] == b'prompt()':
return 0, self._matched_prompt, b''
try:
if not signal.getsignal(signal.SIGALRM):
signal.signal(signal.SIGALRM, self.alarm_handler)
signal.alarm(self._play_context.timeout)
out = self.send(obj['command'], obj.get('prompt'), obj.get('answer'), obj.get('sendonly'))
signal.alarm(0)
return 0, out, b''
except (AnsibleConnectionFailure, ValueError) as exc:
return 1, b'', to_bytes(exc)

View file

@ -100,6 +100,7 @@ with warnings.catch_warnings():
class MyAddPolicy(object): class MyAddPolicy(object):
""" """
Based on AutoAddPolicy in paramiko so we can determine when keys are added Based on AutoAddPolicy in paramiko so we can determine when keys are added
and also prompt for input. and also prompt for input.
Policy for automatically adding the hostname and new host key to the Policy for automatically adding the hostname and new host key to the
@ -114,8 +115,13 @@ class MyAddPolicy(object):
if all((C.HOST_KEY_CHECKING, not C.PARAMIKO_HOST_KEY_AUTO_ADD)): if all((C.HOST_KEY_CHECKING, not C.PARAMIKO_HOST_KEY_AUTO_ADD)):
fingerprint = hexlify(key.get_fingerprint())
ktype = key.get_name()
if C.USE_PERSISTENT_CONNECTIONS: if C.USE_PERSISTENT_CONNECTIONS:
raise AnsibleConnectionFailure('rejected %s host key for host %s: %s' % (key.get_name(), hostname, hexlify(key.get_fingerprint()))) # don't print the prompt string since the user cannot respond
# to the question anyway
raise AnsibleError(AUTHENTICITY_MSG[1:92] % (hostname, ktype, fingerprint))
self.connection.connection_lock() self.connection.connection_lock()
@ -125,9 +131,6 @@ class MyAddPolicy(object):
# clear out any premature input on sys.stdin # clear out any premature input on sys.stdin
tcflush(sys.stdin, TCIFLUSH) tcflush(sys.stdin, TCIFLUSH)
fingerprint = hexlify(key.get_fingerprint())
ktype = key.get_name()
inp = input(AUTHENTICITY_MSG % (hostname, ktype, fingerprint)) inp = input(AUTHENTICITY_MSG % (hostname, ktype, fingerprint))
sys.stdin = old_stdin sys.stdin = old_stdin

View file

@ -1,4 +1,4 @@
# (c) 2017 Red Hat Inc. # 2017 Red Hat Inc.
# (c) 2017 Ansible Project # (c) 2017 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
@ -13,15 +13,19 @@ DOCUMENTATION = """
- This is a helper plugin to allow making other connections persistent. - This is a helper plugin to allow making other connections persistent.
version_added: "2.3" version_added: "2.3"
""" """
import re
import os import os
import sys
import pty import pty
import json
import subprocess import subprocess
from ansible.module_utils._text import to_bytes, to_text from ansible import constants as C
from ansible.module_utils.six.moves import cPickle from ansible.plugins.loader import connection_loader
from ansible.plugins.connection import ConnectionBase from ansible.plugins.connection import ConnectionBase
from ansible.module_utils._text import to_text
from ansible.module_utils.six.moves import cPickle
from ansible.module_utils.connection import Connection as SocketConnection
from ansible.errors import AnsibleError
try: try:
from __main__ import display from __main__ import display
@ -40,8 +44,38 @@ class Connection(ConnectionBase):
self._connected = True self._connected = True
return self return self
def _do_it(self, action): def exec_command(self, cmd, in_data=None, sudoable=True):
display.vvvv('exec_command(), socket_path=%s' % self.socket_path, host=self._play_context.remote_addr)
connection = SocketConnection(self.socket_path)
out = connection.exec_command(cmd, in_data=in_data, sudoable=sudoable)
return 0, out, ''
def put_file(self, in_path, out_path):
pass
def fetch_file(self, in_path, out_path):
pass
def close(self):
self._connected = False
def run(self):
"""Returns the path of the persistent connection socket.
Attempts to ensure (within playcontext.timeout seconds) that the
socket path exists. If the path exists (or the timeout has expired),
returns the socket path.
"""
display.vvvv('starting connection from persistent connection plugin', host=self._play_context.remote_addr)
socket_path = self._start_connection()
display.vvvv('local domain socket path is %s' % socket_path, host=self._play_context.remote_addr)
setattr(self, '_socket_path', socket_path)
return socket_path
def _start_connection(self):
'''
Starts the persistent connection
'''
master, slave = pty.openpty() master, slave = pty.openpty()
p = subprocess.Popen(["ansible-connection"], stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE) p = subprocess.Popen(["ansible-connection"], stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdin = os.fdopen(master, 'wb', 0) stdin = os.fdopen(master, 'wb', 0)
@ -56,40 +90,23 @@ class Connection(ConnectionBase):
stdin.write(src) stdin.write(src)
stdin.write(b'\n#END_INIT#\n') stdin.write(b'\n#END_INIT#\n')
stdin.write(to_bytes(action))
stdin.write(b'\n\n')
(stdout, stderr) = p.communicate() (stdout, stderr) = p.communicate()
stdin.close() stdin.close()
return (p.returncode, stdout, stderr) if p.returncode == 0:
result = json.loads(to_text(stdout, errors='surrogate_then_replace'))
else:
result = json.loads(to_text(stderr, errors='surrogate_then_replace'))
def exec_command(self, cmd, in_data=None, sudoable=True): if 'messages' in result:
super(Connection, self).exec_command(cmd, in_data=in_data, sudoable=sudoable) for msg in result.get('messages'):
return self._do_it('EXEC: ' + cmd) display.vvvv('%s' % msg, host=self._play_context.remote_addr)
def put_file(self, in_path, out_path): if 'error' in result:
super(Connection, self).put_file(in_path, out_path) if self._play_context.verbosity > 2:
self._do_it('PUT: %s %s' % (in_path, out_path)) msg = "The full traceback is:\n" + result['exception']
display.display(result['exception'], color=C.COLOR_ERROR)
raise AnsibleError(result['error'])
def fetch_file(self, in_path, out_path): return result['socket_path']
super(Connection, self).fetch_file(in_path, out_path)
self._do_it('FETCH: %s %s' % (in_path, out_path))
def close(self):
self._connected = False
def run(self):
"""Returns the path of the persistent connection socket.
Attempts to ensure (within playcontext.timeout seconds) that the
socket path exists. If the path exists (or the timeout has expired),
returns the socket path.
"""
socket_path = None
rc, out, err = self._do_it('RUN:')
match = re.search(br"#SOCKET_PATH#: (\S+)", out)
if match:
socket_path = to_text(match.group(1).strip(), errors='surrogate_or_strict')
return socket_path

View file

@ -56,20 +56,12 @@ class TerminalBase(with_metaclass(ABCMeta, object)):
self._connection = connection self._connection = connection
def _exec_cli_command(self, cmd, check_rc=True): def _exec_cli_command(self, cmd, check_rc=True):
""" '''
Executes a CLI command on the device Executes the CLI command on the remote device and returns the output
:arg cmd: Byte string consisting of the command to execute :arg cmd: Byte string command to be executed
:kwarg check_rc: If True, the default, raise an '''
:exc:`AnsibleConnectionFailure` if the return code from the return self._connection.exec_command(cmd)
command is nonzero
:returns: A tuple of return code, stdout, and stderr from running the
command. stdout and stderr are both byte strings.
"""
rc, out, err = self._connection.exec_command(cmd)
if check_rc and rc != 0:
raise AnsibleConnectionFailure(err)
return rc, out, err
def _get_prompt(self): def _get_prompt(self):
""" """
@ -77,9 +69,8 @@ class TerminalBase(with_metaclass(ABCMeta, object)):
:returns: A byte string of the prompt :returns: A byte string of the prompt
""" """
for cmd in (b'\n', b'prompt()'): self._exec_cli_command(b'\n')
rc, out, err = self._exec_cli_command(cmd) return self._connection._matched_prompt
return out
def on_open_shell(self): def on_open_shell(self):
"""Called after the SSH session is established """Called after the SSH session is established

View file

@ -36,21 +36,21 @@ except ImportError:
class TerminalModule(TerminalBase): class TerminalModule(TerminalBase):
terminal_stdout_re = [ terminal_stdout_re = [
re.compile(r"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$|%"), re.compile(br"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$|%"),
] ]
terminal_stderr_re = [ terminal_stderr_re = [
re.compile(r"unknown command"), re.compile(br"unknown command"),
re.compile(r"syntax error,") re.compile(br"syntax error,")
] ]
def on_open_shell(self): def on_open_shell(self):
try: try:
prompt = self._get_prompt() prompt = self._get_prompt()
if prompt.strip().endswith('%'): if prompt.strip().endswith(b'%'):
display.vvv('starting cli', self._connection._play_context.remote_addr) display.vvv('starting cli', self._connection._play_context.remote_addr)
self._exec_cli_command('cli') self._exec_cli_command('cli')
for c in ['set cli timestamp disable', 'set cli screen-length 0', 'set cli screen-width 1024']: for c in (b'set cli timestamp disable', b'set cli screen-length 0', b'set cli screen-width 1024'):
self._exec_cli_command(c) self._exec_cli_command(c)
except AnsibleConnectionFailure: except AnsibleConnectionFailure:
raise AnsibleConnectionFailure('unable to set terminal parameters') raise AnsibleConnectionFailure('unable to set terminal parameters')

View file

@ -1,28 +1,16 @@
# # (c) 2017, Peter Sprygada <psprygad@redhat.com>
# (c) 2016 Red Hat Inc. # (c) 2017 Ansible Project
# # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import json import json
import traceback import traceback
from ansible import constants as C
from ansible.module_utils._text import to_text from ansible.module_utils._text import to_text
try: try:
from __main__ import display from __main__ import display
except ImportError: except ImportError:
@ -30,13 +18,13 @@ except ImportError:
display = Display() display = Display()
class Rpc: class JsonRpcServer(object):
def __init__(self, *args, **kwargs): _objects = set()
self._rpc = set()
super(Rpc, self).__init__(*args, **kwargs) def handle_request(self, request):
request = json.loads(to_text(request, errors='surrogate_then_replace'))
def _exec_rpc(self, request):
method = request.get('method') method = request.get('method')
if method.startswith('rpc.') or method.startswith('_'): if method.startswith('rpc.') or method.startswith('_'):
@ -45,6 +33,7 @@ class Rpc:
params = request.get('params') params = request.get('params')
setattr(self, '_identifier', request.get('id')) setattr(self, '_identifier', request.get('id'))
args = [] args = []
kwargs = {} kwargs = {}
@ -54,10 +43,15 @@ class Rpc:
kwargs = params kwargs = params
rpc_method = None rpc_method = None
for obj in self._rpc:
rpc_method = getattr(obj, method, None) if method in ('shutdown', 'reset'):
if rpc_method: rpc_method = getattr(self, 'shutdown')
break
else:
for obj in self._objects:
rpc_method = getattr(obj, method, None)
if rpc_method:
break
if not rpc_method: if not rpc_method:
error = self.method_not_found() error = self.method_not_found()
@ -66,7 +60,7 @@ class Rpc:
try: try:
result = rpc_method(*args, **kwargs) result = rpc_method(*args, **kwargs)
except Exception as exc: except Exception as exc:
display.display(traceback.format_exc(), log_only=True) display.vvv(traceback.format_exc())
error = self.internal_error(data=to_text(exc, errors='surrogate_then_replace')) error = self.internal_error(data=to_text(exc, errors='surrogate_then_replace'))
response = json.dumps(error) response = json.dumps(error)
else: else:
@ -78,8 +72,12 @@ class Rpc:
response = json.dumps(response) response = json.dumps(response)
delattr(self, '_identifier') delattr(self, '_identifier')
return response return response
def register(self, obj):
self._objects.add(obj)
def header(self): def header(self):
return {'jsonrpc': '2.0', 'id': self._identifier} return {'jsonrpc': '2.0', 'id': self._identifier}

View file

@ -405,6 +405,7 @@ class TestActionBase(unittest.TestCase):
mock_connection = MagicMock() mock_connection = MagicMock()
mock_connection.build_module_command.side_effect = build_module_command mock_connection.build_module_command.side_effect = build_module_command
mock_connection.socket_path = None
mock_connection._shell.get_remote_filename.return_value = 'copy.py' mock_connection._shell.get_remote_filename.return_value = 'copy.py'
mock_connection._shell.join_path.side_effect = os.path.join mock_connection._shell.join_path.side_effect = os.path.join

View file

@ -37,6 +37,7 @@ from ansible.plugins.connection.paramiko_ssh import Connection as ParamikoConnec
from ansible.plugins.connection.ssh import Connection as SSHConnection from ansible.plugins.connection.ssh import Connection as SSHConnection
from ansible.plugins.connection.docker import Connection as DockerConnection from ansible.plugins.connection.docker import Connection as DockerConnection
# from ansible.plugins.connection.winrm import Connection as WinRmConnection # from ansible.plugins.connection.winrm import Connection as WinRmConnection
from ansible.plugins.connection.netconf import Connection as NetconfConnection
from ansible.plugins.connection.network_cli import Connection as NetworkCliConnection from ansible.plugins.connection.network_cli import Connection as NetworkCliConnection
@ -140,7 +141,9 @@ class TestConnectionBaseClass(unittest.TestCase):
def test_network_cli_connection_module(self): def test_network_cli_connection_module(self):
self.assertIsInstance(NetworkCliConnection(self.play_context, self.in_stream), NetworkCliConnection) self.assertIsInstance(NetworkCliConnection(self.play_context, self.in_stream), NetworkCliConnection)
self.assertIsInstance(NetworkCliConnection(self.play_context, self.in_stream), ParamikoConnection)
def test_netconf_connection_module(self):
self.assertIsInstance(NetconfConnection(self.play_context, self.in_stream), NetconfConnection)
def test_check_password_prompt(self): def test_check_password_prompt(self):
local = ( local = (

View file

@ -69,9 +69,9 @@ class TestNetconfConnectionClass(unittest.TestCase):
conn = netconf.Connection(pc, new_stdin) conn = netconf.Connection(pc, new_stdin)
mock_manager = MagicMock(name='self._manager.connect') mock_manager = MagicMock()
type(mock_manager).session_id = PropertyMock(return_value='123456789') mock_manager.session_id = '123456789'
netconf.manager.connect.return_value = mock_manager netconf.manager.connect = MagicMock(return_value=mock_manager)
conn._play_context.network_os = 'default' conn._play_context.network_os = 'default'
rc, out, err = conn._connect() rc, out, err = conn._connect()
@ -88,22 +88,16 @@ class TestNetconfConnectionClass(unittest.TestCase):
conn = netconf.Connection(pc, new_stdin) conn = netconf.Connection(pc, new_stdin)
conn._connected = True conn._connected = True
mock_manager = MagicMock(name='self._manager')
mock_reply = MagicMock(name='reply') mock_reply = MagicMock(name='reply')
type(mock_reply).data_xml = PropertyMock(return_value='<test/>') type(mock_reply).data_xml = PropertyMock(return_value='<test/>')
mock_manager = MagicMock(name='self._manager')
mock_manager.rpc.return_value = mock_reply mock_manager.rpc.return_value = mock_reply
conn._manager = mock_manager conn._manager = mock_manager
rc, out, err = conn.exec_command('<test/>') out = conn.exec_command('<test/>')
netconf.to_ele.assert_called_with('<test/>') self.assertEqual('<test/>', out)
self.assertEqual(0, rc)
self.assertEqual(b'<test/>', out)
self.assertEqual(b'', err)
def test_netconf_exec_command_invalid_request(self): def test_netconf_exec_command_invalid_request(self):
pc = PlayContext() pc = PlayContext()
@ -112,10 +106,11 @@ class TestNetconfConnectionClass(unittest.TestCase):
conn = netconf.Connection(pc, new_stdin) conn = netconf.Connection(pc, new_stdin)
conn._connected = True conn._connected = True
mock_manager = MagicMock(name='self._manager')
conn._manager = mock_manager
netconf.to_ele.return_value = None netconf.to_ele.return_value = None
rc, out, err = conn.exec_command('test string') out = conn.exec_command('test string')
self.assertEqual(1, rc) self.assertEqual('unable to parse request', out)
self.assertEqual(b'', out)
self.assertEqual(b'unable to parse request', err)

View file

@ -35,7 +35,7 @@ from ansible.plugins.connection import network_cli
class TestConnectionClass(unittest.TestCase): class TestConnectionClass(unittest.TestCase):
@patch("ansible.plugins.connection.network_cli._Connection._connect") @patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect")
def test_network_cli__connect_error(self, mocked_super): def test_network_cli__connect_error(self, mocked_super):
pc = PlayContext() pc = PlayContext()
new_stdin = StringIO() new_stdin = StringIO()
@ -47,7 +47,7 @@ class TestConnectionClass(unittest.TestCase):
pc.network_os = None pc.network_os = None
self.assertRaises(AnsibleConnectionFailure, conn._connect) self.assertRaises(AnsibleConnectionFailure, conn._connect)
@patch("ansible.plugins.connection.network_cli._Connection._connect") @patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect")
def test_network_cli__invalid_os(self, mocked_super): def test_network_cli__invalid_os(self, mocked_super):
pc = PlayContext() pc = PlayContext()
new_stdin = StringIO() new_stdin = StringIO()
@ -60,7 +60,7 @@ class TestConnectionClass(unittest.TestCase):
self.assertRaises(AnsibleConnectionFailure, conn._connect) self.assertRaises(AnsibleConnectionFailure, conn._connect)
@patch("ansible.plugins.connection.network_cli.terminal_loader") @patch("ansible.plugins.connection.network_cli.terminal_loader")
@patch("ansible.plugins.connection.network_cli._Connection._connect") @patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect")
def test_network_cli__connect(self, mocked_super, mocked_terminal_loader): def test_network_cli__connect(self, mocked_super, mocked_terminal_loader):
pc = PlayContext() pc = PlayContext()
new_stdin = StringIO() new_stdin = StringIO()
@ -70,22 +70,21 @@ class TestConnectionClass(unittest.TestCase):
conn.ssh = MagicMock() conn.ssh = MagicMock()
conn.receive = MagicMock() conn.receive = MagicMock()
conn._terminal = MagicMock()
mock_terminal = MagicMock()
conn._terminal = mock_terminal
conn._connect() conn._connect()
self.assertTrue(conn._terminal.on_open_shell.called) self.assertTrue(conn._terminal.on_open_shell.called)
self.assertFalse(conn._terminal.on_authorize.called) self.assertFalse(conn._terminal.on_authorize.called)
conn._play_context.become = True conn._play_context.become = True
conn._play_context.become_method = 'enable'
conn._play_context.become_pass = 'password' conn._play_context.become_pass = 'password'
conn._connected = False
conn._connect() conn._connect()
conn._terminal.on_authorize.assert_called_with(passwd='password') conn._terminal.on_authorize.assert_called_with(passwd='password')
@patch("ansible.plugins.connection.network_cli._Connection.close") @patch("ansible.plugins.connection.network_cli.ParamikoSshConnection.close")
def test_network_cli_close(self, mocked_super): def test_network_cli_close(self, mocked_super):
pc = PlayContext() pc = PlayContext()
new_stdin = StringIO() new_stdin = StringIO()
@ -93,20 +92,14 @@ class TestConnectionClass(unittest.TestCase):
terminal = MagicMock(supports_multiplexing=False) terminal = MagicMock(supports_multiplexing=False)
conn._terminal = terminal conn._terminal = terminal
conn._ssh_shell = MagicMock()
conn.close() conn._connected = True
conn._shell = MagicMock()
conn.close() conn.close()
self.assertTrue(terminal.on_close_shell.called) self.assertTrue(terminal.on_close_shell.called)
self.assertIsNone(conn._ssh_shell)
terminal.supports_multiplexing = True @patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect")
conn.close()
self.assertIsNone(conn._shell)
@patch("ansible.plugins.connection.network_cli._Connection._connect")
def test_network_cli_exec_command(self, mocked_super): def test_network_cli_exec_command(self, mocked_super):
pc = PlayContext() pc = PlayContext()
new_stdin = StringIO() new_stdin = StringIO()
@ -114,23 +107,17 @@ class TestConnectionClass(unittest.TestCase):
mock_send = MagicMock(return_value=b'command response') mock_send = MagicMock(return_value=b'command response')
conn.send = mock_send conn.send = mock_send
conn._ssh_shell = MagicMock()
# test sending a single command and converting to dict # test sending a single command and converting to dict
rc, out, err = conn.exec_command('command') out = conn.exec_command('command')
self.assertEqual(out, b'command response') self.assertEqual(out, b'command response')
mock_send.assert_called_with(b'command', None, None, None) mock_send.assert_called_with(command=b'command')
# test sending a json string # test sending a json string
rc, out, err = conn.exec_command(json.dumps({'command': 'command'})) out = conn.exec_command(json.dumps({'command': 'command'}))
self.assertEqual(out, b'command response') self.assertEqual(out, b'command response')
mock_send.assert_called_with(b'command', None, None, None) mock_send.assert_called_with(command=b'command')
conn._shell = MagicMock()
# test _shell already open
rc, out, err = conn.exec_command('command')
self.assertEqual(out, b'command response')
mock_send.assert_called_with(b'command', None, None, None)
def test_network_cli_send(self): def test_network_cli_send(self):
pc = PlayContext() pc = PlayContext()
@ -142,7 +129,7 @@ class TestConnectionClass(unittest.TestCase):
conn._terminal = mock__terminal conn._terminal = mock__terminal
mock__shell = MagicMock() mock__shell = MagicMock()
conn._shell = mock__shell conn._ssh_shell = mock__shell
response = b"""device#command response = b"""device#command
command response command response
@ -155,7 +142,7 @@ class TestConnectionClass(unittest.TestCase):
output = conn.send(b'command', None, None, None) output = conn.send(b'command', None, None, None)
mock__shell.sendall.assert_called_with(b'command\r') mock__shell.sendall.assert_called_with(b'command\r')
self.assertEqual(output, b'command response') self.assertEqual(output, 'command response')
mock__shell.reset_mock() mock__shell.reset_mock()
mock__shell.recv.return_value = b"ERROR: error message device#" mock__shell.recv.return_value = b"ERROR: error message device#"