Use a decorator to ensure jit connection, instead of an explicit call to _connect
This commit is contained in:
parent
f7839dee11
commit
9754c67138
5 changed files with 27 additions and 6 deletions
|
@ -210,7 +210,6 @@ class TaskExecutor:
|
||||||
# get the connection and the handler for this execution
|
# get the connection and the handler for this execution
|
||||||
self._connection = self._get_connection(variables)
|
self._connection = self._get_connection(variables)
|
||||||
self._connection.set_host_overrides(host=self._host)
|
self._connection.set_host_overrides(host=self._host)
|
||||||
self._connection._connect()
|
|
||||||
|
|
||||||
self._handler = self._get_action_handler(connection=self._connection, templar=templar)
|
self._handler = self._get_action_handler(connection=self._connection, templar=templar)
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ __metaclass__ = type
|
||||||
|
|
||||||
from abc import ABCMeta, abstractmethod, abstractproperty
|
from abc import ABCMeta, abstractmethod, abstractproperty
|
||||||
|
|
||||||
|
from functools import wraps
|
||||||
from six import with_metaclass
|
from six import with_metaclass
|
||||||
|
|
||||||
from ansible import constants as C
|
from ansible import constants as C
|
||||||
|
@ -32,7 +33,16 @@ from ansible.errors import AnsibleError
|
||||||
# which may want to output display/logs too
|
# which may want to output display/logs too
|
||||||
from ansible.utils.display import Display
|
from ansible.utils.display import Display
|
||||||
|
|
||||||
__all__ = ['ConnectionBase']
|
__all__ = ['ConnectionBase', 'ensure_connect']
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_connect(func):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapped(self, *args, **kwargs):
|
||||||
|
self._connect()
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
class ConnectionBase(with_metaclass(ABCMeta, object)):
|
class ConnectionBase(with_metaclass(ABCMeta, object)):
|
||||||
'''
|
'''
|
||||||
|
|
|
@ -41,7 +41,7 @@ from binascii import hexlify
|
||||||
|
|
||||||
from ansible import constants as C
|
from ansible import constants as C
|
||||||
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
|
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
|
||||||
from ansible.plugins.connections import ConnectionBase
|
from ansible.plugins.connections import ConnectionBase, ensure_connect
|
||||||
from ansible.utils.path import makedirs_safe
|
from ansible.utils.path import makedirs_safe
|
||||||
|
|
||||||
AUTHENTICITY_MSG="""
|
AUTHENTICITY_MSG="""
|
||||||
|
@ -61,6 +61,7 @@ with warnings.catch_warnings():
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MyAddPolicy(object):
|
class MyAddPolicy(object):
|
||||||
"""
|
"""
|
||||||
Based on AutoAddPolicy in paramiko so we can determine when keys are added
|
Based on AutoAddPolicy in paramiko so we can determine when keys are added
|
||||||
|
@ -188,6 +189,7 @@ class Connection(ConnectionBase):
|
||||||
|
|
||||||
return ssh
|
return ssh
|
||||||
|
|
||||||
|
@ensure_connect
|
||||||
def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None):
|
def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None):
|
||||||
''' run a command on the remote host '''
|
''' run a command on the remote host '''
|
||||||
|
|
||||||
|
@ -248,6 +250,7 @@ class Connection(ConnectionBase):
|
||||||
|
|
||||||
return (chan.recv_exit_status(), '', no_prompt_out + stdout, no_prompt_out + stderr)
|
return (chan.recv_exit_status(), '', no_prompt_out + stdout, no_prompt_out + stderr)
|
||||||
|
|
||||||
|
@ensure_connect
|
||||||
def put_file(self, in_path, out_path):
|
def put_file(self, in_path, out_path):
|
||||||
''' transfer a file from local to remote '''
|
''' transfer a file from local to remote '''
|
||||||
|
|
||||||
|
@ -272,9 +275,10 @@ class Connection(ConnectionBase):
|
||||||
if cache_key in SFTP_CONNECTION_CACHE:
|
if cache_key in SFTP_CONNECTION_CACHE:
|
||||||
return SFTP_CONNECTION_CACHE[cache_key]
|
return SFTP_CONNECTION_CACHE[cache_key]
|
||||||
else:
|
else:
|
||||||
result = SFTP_CONNECTION_CACHE[cache_key] = self.connect().ssh.open_sftp()
|
result = SFTP_CONNECTION_CACHE[cache_key] = self._connect().ssh.open_sftp()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ensure_connect
|
||||||
def fetch_file(self, in_path, out_path):
|
def fetch_file(self, in_path, out_path):
|
||||||
''' save a remote file to the specified path '''
|
''' save a remote file to the specified path '''
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,8 @@ from hashlib import sha1
|
||||||
|
|
||||||
from ansible import constants as C
|
from ansible import constants as C
|
||||||
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
|
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
|
||||||
from ansible.plugins.connections import ConnectionBase
|
from ansible.plugins.connections import ConnectionBase, ensure_connect
|
||||||
|
|
||||||
|
|
||||||
class Connection(ConnectionBase):
|
class Connection(ConnectionBase):
|
||||||
''' ssh based connections '''
|
''' ssh based connections '''
|
||||||
|
@ -269,6 +270,7 @@ class Connection(ConnectionBase):
|
||||||
self._display.vvv("EXEC previous known host file not found for {0}".format(host))
|
self._display.vvv("EXEC previous known host file not found for {0}".format(host))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ensure_connect
|
||||||
def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None):
|
def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None):
|
||||||
''' run a command on the remote host '''
|
''' run a command on the remote host '''
|
||||||
|
|
||||||
|
@ -390,6 +392,7 @@ class Connection(ConnectionBase):
|
||||||
|
|
||||||
return (p.returncode, '', no_prompt_out + stdout, no_prompt_err + stderr)
|
return (p.returncode, '', no_prompt_out + stdout, no_prompt_err + stderr)
|
||||||
|
|
||||||
|
@ensure_connect
|
||||||
def put_file(self, in_path, out_path):
|
def put_file(self, in_path, out_path):
|
||||||
''' transfer a file from local to remote '''
|
''' transfer a file from local to remote '''
|
||||||
self._display.vvv("PUT {0} TO {1}".format(in_path, out_path), host=self._connection_info.remote_addr)
|
self._display.vvv("PUT {0} TO {1}".format(in_path, out_path), host=self._connection_info.remote_addr)
|
||||||
|
@ -425,6 +428,7 @@ class Connection(ConnectionBase):
|
||||||
if returncode != 0:
|
if returncode != 0:
|
||||||
raise AnsibleError("failed to transfer file to {0}:\n{1}\n{2}".format(out_path, stdout, stderr))
|
raise AnsibleError("failed to transfer file to {0}:\n{1}\n{2}".format(out_path, stdout, stderr))
|
||||||
|
|
||||||
|
@ensure_connect
|
||||||
def fetch_file(self, in_path, out_path):
|
def fetch_file(self, in_path, out_path):
|
||||||
''' fetch a file from remote to local '''
|
''' fetch a file from remote to local '''
|
||||||
self._display.vvv("FETCH {0} TO {1}".format(in_path, out_path), host=self._connection_info.remote_addr)
|
self._display.vvv("FETCH {0} TO {1}".format(in_path, out_path), host=self._connection_info.remote_addr)
|
||||||
|
|
|
@ -42,10 +42,11 @@ except ImportError:
|
||||||
|
|
||||||
from ansible import constants as C
|
from ansible import constants as C
|
||||||
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
|
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound
|
||||||
from ansible.plugins.connections import ConnectionBase
|
from ansible.plugins.connections import ConnectionBase, ensure_connect
|
||||||
from ansible.plugins import shell_loader
|
from ansible.plugins import shell_loader
|
||||||
from ansible.utils.path import makedirs_safe
|
from ansible.utils.path import makedirs_safe
|
||||||
|
|
||||||
|
|
||||||
class Connection(ConnectionBase):
|
class Connection(ConnectionBase):
|
||||||
'''WinRM connections over HTTP/HTTPS.'''
|
'''WinRM connections over HTTP/HTTPS.'''
|
||||||
|
|
||||||
|
@ -151,6 +152,7 @@ class Connection(ConnectionBase):
|
||||||
self.protocol = self._winrm_connect()
|
self.protocol = self._winrm_connect()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ensure_connect
|
||||||
def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None):
|
def exec_command(self, cmd, tmp_path, executable='/bin/sh', in_data=None):
|
||||||
|
|
||||||
cmd = cmd.encode('utf-8')
|
cmd = cmd.encode('utf-8')
|
||||||
|
@ -172,6 +174,7 @@ class Connection(ConnectionBase):
|
||||||
raise AnsibleError("failed to exec cmd %s" % cmd)
|
raise AnsibleError("failed to exec cmd %s" % cmd)
|
||||||
return (result.status_code, '', result.std_out.encode('utf-8'), result.std_err.encode('utf-8'))
|
return (result.status_code, '', result.std_out.encode('utf-8'), result.std_err.encode('utf-8'))
|
||||||
|
|
||||||
|
@ensure_connect
|
||||||
def put_file(self, in_path, out_path):
|
def put_file(self, in_path, out_path):
|
||||||
self._display.vvv("PUT %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr)
|
self._display.vvv("PUT %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr)
|
||||||
if not os.path.exists(in_path):
|
if not os.path.exists(in_path):
|
||||||
|
@ -210,6 +213,7 @@ class Connection(ConnectionBase):
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise AnsibleError("failed to transfer file to %s" % out_path)
|
raise AnsibleError("failed to transfer file to %s" % out_path)
|
||||||
|
|
||||||
|
@ensure_connect
|
||||||
def fetch_file(self, in_path, out_path):
|
def fetch_file(self, in_path, out_path):
|
||||||
out_path = out_path.replace('\\', '/')
|
out_path = out_path.replace('\\', '/')
|
||||||
self._display.vvv("FETCH %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr)
|
self._display.vvv("FETCH %s TO %s" % (in_path, out_path), host=self._connection_info.remote_addr)
|
||||||
|
|
Loading…
Reference in a new issue