Adding a persistent connection utility
This commit is contained in:
parent
0b96d61162
commit
26ec2ecfce
6 changed files with 428 additions and 4 deletions
302
bin/ansible-connection
Executable file
302
bin/ansible-connection
Executable file
|
@ -0,0 +1,302 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# (c) 2016, Ansible, Inc. <support@ansible.com>
|
||||
#
|
||||
# This file is part of Ansible
|
||||
#
|
||||
# Ansible is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# Ansible is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
########################################################
|
||||
from __future__ import (absolute_import, division, print_function)
|
||||
__metaclass__ = type
|
||||
|
||||
__requires__ = ['ansible']
|
||||
try:
|
||||
import pkg_resources
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
import fcntl
|
||||
import hashlib
|
||||
import os
|
||||
import shlex
|
||||
import signal
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
#import q
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
from ansible import constants as C
|
||||
from ansible.module_utils._text import to_bytes, to_native
|
||||
from ansible.module_utils.six.moves import cPickle, StringIO
|
||||
from ansible.playbook.play_context import PlayContext
|
||||
from ansible.plugins import connection_loader
|
||||
from ansible.utils.path import unfrackpath, makedirs_safe
|
||||
|
||||
def do_fork():
|
||||
'''
|
||||
Does the required double fork for a daemon process. Based on
|
||||
http://code.activestate.com/recipes/66012-fork-a-daemon-process-on-unix/
|
||||
'''
|
||||
try:
|
||||
pid = os.fork()
|
||||
if pid > 0:
|
||||
return pid
|
||||
|
||||
os.chdir("/")
|
||||
os.setsid()
|
||||
os.umask(0)
|
||||
|
||||
try:
|
||||
pid = os.fork()
|
||||
if pid > 0:
|
||||
sys.exit(0)
|
||||
|
||||
os.close(sys.stdin.fileno())
|
||||
os.close(sys.stdout.fileno())
|
||||
os.close(sys.stderr.fileno())
|
||||
|
||||
return pid
|
||||
except OSError as e:
|
||||
sys.exit(1)
|
||||
except OSError as e:
|
||||
sys.exit(1)
|
||||
|
||||
def send_data(s, data):
|
||||
packed_len = struct.pack('!Q',len(data))
|
||||
return s.sendall(packed_len + data)
|
||||
|
||||
def recv_data(s):
|
||||
header_len = 8 # size of a packed unsigned long long
|
||||
data = b""
|
||||
while len(data) < header_len:
|
||||
d = s.recv(header_len - len(data))
|
||||
if not d:
|
||||
return None
|
||||
data += d
|
||||
data_len = struct.unpack('!Q',data[:header_len])[0]
|
||||
data = data[header_len:]
|
||||
while len(data) < data_len:
|
||||
d = s.recv(data_len - len(data))
|
||||
if not d:
|
||||
return None
|
||||
data += d
|
||||
return data
|
||||
|
||||
class Server():
|
||||
def __init__(self, path, play_context):
|
||||
self.path = path
|
||||
self.play_context = play_context
|
||||
|
||||
# FIXME: the connection loader here is created brand new,
|
||||
# so it will not see any custom paths loaded (ie. via
|
||||
# roles), so we will need to serialize the connection
|
||||
# loader and send it over as we do the PlayContext
|
||||
# in the main() method below.
|
||||
self.conn = connection_loader.get(play_context.connection, play_context, sys.stdin)
|
||||
self.conn._connect()
|
||||
|
||||
#q.q("done setting up connection and connected")
|
||||
|
||||
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
self.socket.bind(path)
|
||||
self.socket.listen(1)
|
||||
|
||||
signal.signal(signal.SIGALRM, self.alarm_handler)
|
||||
|
||||
def alarm_handler(self, signum, frame):
|
||||
'''
|
||||
Alarm handler
|
||||
'''
|
||||
# FIXME: this should also set internal flags for other
|
||||
# areas of code to check, so they can terminate
|
||||
# earlier than the socket going back to the accept
|
||||
# call and failing there.
|
||||
self.socket.close()
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
while True:
|
||||
# set the alarm, if we don't get an accept before it
|
||||
# goes off we exit (via an exception caused by the socket
|
||||
# getting closed while waiting on accept())
|
||||
# FIXME: is this the best way to exit? as noted above in the
|
||||
# handler we should probably be setting a flag to check
|
||||
# here and in other parts of the code
|
||||
signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT)
|
||||
try:
|
||||
(s, addr) = self.socket.accept()
|
||||
# clear the alarm
|
||||
# FIXME: potential race condition here between the accept and
|
||||
# time to this call.
|
||||
signal.alarm(0)
|
||||
except:
|
||||
break
|
||||
|
||||
while True:
|
||||
data = recv_data(s)
|
||||
if not data:
|
||||
break
|
||||
|
||||
rc = 255
|
||||
try:
|
||||
if data.startswith(b'EXEC: '):
|
||||
cmd = data.split(b'EXEC: ')[1]
|
||||
(rc, stdout, stderr) = self.conn.exec_command(cmd)
|
||||
elif data.startswith(b'PUT: ') or data.startswith(b'FETCH: '):
|
||||
(op, src, dst) = shlex.split(to_native(data))
|
||||
stdout = stderr = ''
|
||||
try:
|
||||
if op == 'FETCH:':
|
||||
self.conn.fetch_file(src, dst)
|
||||
elif op == 'PUT:':
|
||||
self.conn.put_file(src, dst)
|
||||
rc = 0
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
stdout = ''
|
||||
stderr = 'Invalid action specified'
|
||||
except:
|
||||
stdout = ''
|
||||
stderr = traceback.format_exc()
|
||||
|
||||
send_data(s, to_bytes(str(rc)))
|
||||
send_data(s, to_bytes(stdout))
|
||||
send_data(s, to_bytes(stderr))
|
||||
s.close()
|
||||
except Exception as e:
|
||||
# FIXME: proper logging and error handling here
|
||||
print("run exception: %s" % e)
|
||||
print(traceback.format_exc())
|
||||
finally:
|
||||
# when done, close the connection properly and cleanup
|
||||
# the socket file so it can be recreated
|
||||
try:
|
||||
self.conn.close()
|
||||
except Exception as e:
|
||||
pass
|
||||
os.remove(self.path)
|
||||
|
||||
def main():
|
||||
try:
|
||||
# read the play context data via stdin, which means depickling it
|
||||
# FIXME: as noted above, we will probably need to deserialize the
|
||||
# connection loader here as well at some point, otherwise this
|
||||
# won't find role- or playbook-based connection plugins
|
||||
cur_line = sys.stdin.readline()
|
||||
init_data = ''
|
||||
while cur_line.strip() != '#END_INIT#':
|
||||
if cur_line == '':
|
||||
raise Exception("EOL found before init data was complete")
|
||||
init_data += cur_line
|
||||
cur_line = sys.stdin.readline()
|
||||
src = BytesIO(to_bytes(init_data))
|
||||
pc_data = cPickle.load(src)
|
||||
src.close()
|
||||
|
||||
pc = PlayContext()
|
||||
pc.deserialize(pc_data)
|
||||
except Exception as e:
|
||||
# FIXME: better error message/handling/logging
|
||||
print("FAIL: %s" % e)
|
||||
print(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
#q.q("done reading in and parsing PlayContext")
|
||||
|
||||
# here we create a hash to use later when creating the socket file,
|
||||
# so we can hide the info about the target host/user/etc.
|
||||
m = hashlib.sha256()
|
||||
for attr in ('connection', 'remote_addr', 'port', 'remote_user'):
|
||||
val = getattr(pc, attr, None)
|
||||
if val:
|
||||
m.update(to_bytes(val))
|
||||
|
||||
# create the persistent connection dir if need be and create the paths
|
||||
# which we will be using later
|
||||
tmp_path = unfrackpath("$HOME/.ansible/pc")
|
||||
makedirs_safe(tmp_path)
|
||||
lk_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path)
|
||||
sf_path = unfrackpath("%s/conn-%s" % (tmp_path, m.hexdigest()[0:12]))
|
||||
|
||||
# if the socket file doesn't exist, spin up the daemon process
|
||||
lock_fd = os.open(lk_path, os.O_RDWR|os.O_CREAT, 0o600)
|
||||
fcntl.lockf(lock_fd, fcntl.LOCK_EX)
|
||||
if not os.path.exists(sf_path):
|
||||
#q.q("creating daemonized connection fork")
|
||||
pid = do_fork()
|
||||
if pid == 0:
|
||||
server = Server(sf_path, pc)
|
||||
fcntl.lockf(lock_fd, fcntl.LOCK_UN)
|
||||
os.close(lock_fd)
|
||||
#q.q("fork done, running server")
|
||||
server.run()
|
||||
#q.q("server run complete, exiting")
|
||||
sys.exit(0)
|
||||
fcntl.lockf(lock_fd, fcntl.LOCK_UN)
|
||||
os.close(lock_fd)
|
||||
|
||||
# now connect to the daemon process
|
||||
# FIXME: if the socket file existed but the daemonized process was killed,
|
||||
# the connection will timeout here. Need to make this more resilient.
|
||||
rc = 0
|
||||
while rc == 0:
|
||||
#q.q("waiting for input")
|
||||
data = sys.stdin.readline()
|
||||
if data == '':
|
||||
#q.q("data was empty, aborting")
|
||||
break
|
||||
if data.strip() == '':
|
||||
#q.q("data was empty line, skipping")
|
||||
continue
|
||||
sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
attempts = 1
|
||||
while True:
|
||||
try:
|
||||
sf.connect(sf_path)
|
||||
break
|
||||
except socket.error:
|
||||
# FIXME: better error handling/logging/message here
|
||||
# FIXME: make # of retries configurable?
|
||||
time.sleep(0.1)
|
||||
attempts += 1
|
||||
if attempts > 10:
|
||||
sys.stderr.write("failed to connect to the host, connection timeout\n")
|
||||
sys.exit(255)
|
||||
|
||||
#q.q("sending data to pipe")
|
||||
send_data(sf, to_bytes(data.strip()))
|
||||
#q.q("getting data back")
|
||||
rc = int(recv_data(sf), 10)
|
||||
#q.q(rc)
|
||||
stdout = recv_data(sf)
|
||||
#q.q(stdout)
|
||||
stderr = recv_data(sf)
|
||||
#q.q(stderr)
|
||||
sys.stdout.write(to_native(stdout))
|
||||
sys.stderr.write(to_native(stderr))
|
||||
#sys.stdout.flush()
|
||||
#sys.stderr.flush()
|
||||
|
||||
sf.close()
|
||||
break
|
||||
sys.exit(rc)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
27
bin/test-output
Executable file
27
bin/test-output
Executable file
|
@ -0,0 +1,27 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import cPickle
|
||||
import sys
|
||||
from cStringIO import StringIO
|
||||
from ansible.playbook.play_context import PlayContext
|
||||
|
||||
p = PlayContext()
|
||||
p.connection = 'paramiko_ssh'
|
||||
p.remote_addr = '192.168.122.100'
|
||||
p.port = 22
|
||||
p.remote_user = 'root'
|
||||
p.password = ''
|
||||
|
||||
src = StringIO()
|
||||
cPickle.dump(p.serialize(), src)
|
||||
sys.stdout.write(src.getvalue())
|
||||
sys.stdout.write('\n#END_INIT#\n')
|
||||
sys.stdout.flush()
|
||||
|
||||
while True:
|
||||
data = sys.stdin.readline()
|
||||
if data == '':
|
||||
break
|
||||
if data.strip() != '':
|
||||
sys.stdout.write(data)
|
||||
sys.stdout.flush()
|
|
@ -298,6 +298,7 @@ DISPLAY_ARGS_TO_STDOUT = get_config(p, DEFAULTS, 'display_args_to_stdout
|
|||
MAX_FILE_SIZE_FOR_DIFF = get_config(p, DEFAULTS, 'max_diff_size', 'ANSIBLE_MAX_DIFF_SIZE', 1024*1024, value_type='integer')
|
||||
|
||||
# CONNECTION RELATED
|
||||
USE_PERSISTENT_CONNECTIONS = get_config(p, DEFAULTS, 'use_persistent_connections', 'ANSIBLE_USE_PERSISTENT_CONNECTIONS', False, value_type='boolean')
|
||||
ANSIBLE_SSH_ARGS = get_config(p, 'ssh_connection', 'ssh_args', 'ANSIBLE_SSH_ARGS', '-C -o ControlMaster=auto -o ControlPersist=60s')
|
||||
ANSIBLE_SSH_CONTROL_PATH = get_config(p, 'ssh_connection', 'control_path', 'ANSIBLE_SSH_CONTROL_PATH', u"%(directory)s/ansible-ssh-%%h-%%p-%%r")
|
||||
ANSIBLE_SSH_CONTROL_PATH_DIR = get_config(p, 'ssh_connection', 'control_path_dir', 'ANSIBLE_SSH_CONTROL_PATH_DIR', u'~/.ansible/cp')
|
||||
|
@ -306,6 +307,7 @@ ANSIBLE_SSH_RETRIES = get_config(p, 'ssh_connection', 'retries', 'ANS
|
|||
ANSIBLE_SSH_EXECUTABLE = get_config(p, 'ssh_connection', 'ssh_executable', 'ANSIBLE_SSH_EXECUTABLE', 'ssh')
|
||||
PARAMIKO_RECORD_HOST_KEYS = get_config(p, 'paramiko_connection', 'record_host_keys', 'ANSIBLE_PARAMIKO_RECORD_HOST_KEYS', True, value_type='boolean')
|
||||
PARAMIKO_PROXY_COMMAND = get_config(p, 'paramiko_connection', 'proxy_command', 'ANSIBLE_PARAMIKO_PROXY_COMMAND', None)
|
||||
PERSISTENT_CONNECT_TIMEOUT = get_config(p, 'persistent_connection', 'connect_timeout', 'ANSIBLE_PERSISTENT_CONNECT_TIMEOUT', 30, value_type='integer')
|
||||
|
||||
# obsolete -- will be formally removed
|
||||
ZEROMQ_PORT = get_config(p, 'fireball_connection', 'zeromq_port', 'ANSIBLE_ZEROMQ_PORT', 5099, value_type='integer')
|
||||
|
|
|
@ -696,7 +696,19 @@ class TaskExecutor:
|
|||
if not check_for_controlpersist(self._play_context.ssh_executable):
|
||||
conn_type = "paramiko"
|
||||
|
||||
connection = self._shared_loader_obj.connection_loader.get(conn_type, self._play_context, self._new_stdin)
|
||||
# if using persistent connections (or the action has set the FORCE_PERSISTENT_CONNECTION
|
||||
# attribute to True), then we use the persistent connection plugion. Otherwise load the
|
||||
# requested connection plugin
|
||||
if C.USE_PERSISTENT_CONNECTIONS or getattr(self, 'FORCE_PERSISTENT_CONNECTION', False) or conn_type == 'persistent':
|
||||
# if someone did `connection: persistent`, default it to using a
|
||||
# persistent paramiko connection to avoid problems
|
||||
if conn_type == 'persistent':
|
||||
self._play_context.connection = 'paramiko'
|
||||
|
||||
connection = self._shared_loader_obj.connection_loader.get('persistent', self._play_context, self._new_stdin)
|
||||
else:
|
||||
connection = self._shared_loader_obj.connection_loader.get(conn_type, self._play_context, self._new_stdin)
|
||||
|
||||
if not connection:
|
||||
raise AnsibleError("the connection plugin '%s' was not found" % conn_type)
|
||||
|
||||
|
|
|
@ -145,9 +145,9 @@ class Connection(ConnectionBase):
|
|||
proxy_command = None
|
||||
# Parse ansible_ssh_common_args, specifically looking for ProxyCommand
|
||||
ssh_args = [
|
||||
getattr(self._play_context, 'ssh_extra_args', ''),
|
||||
getattr(self._play_context, 'ssh_common_args', ''),
|
||||
getattr(self._play_context, 'ssh_args', ''),
|
||||
getattr(self._play_context, 'ssh_extra_args', '') or '',
|
||||
getattr(self._play_context, 'ssh_common_args', '') or '',
|
||||
getattr(self._play_context, 'ssh_args', '') or '',
|
||||
]
|
||||
if ssh_args is not None:
|
||||
args = self._split_ssh_args(' '.join(ssh_args))
|
||||
|
|
81
lib/ansible/plugins/connection/persistent.py
Normal file
81
lib/ansible/plugins/connection/persistent.py
Normal file
|
@ -0,0 +1,81 @@
|
|||
# (c) 2016 Red Hat Inc.
|
||||
#
|
||||
# This file is part of Ansible
|
||||
#
|
||||
# Ansible is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# Ansible is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from __future__ import (absolute_import, division, print_function)
|
||||
__metaclass__ = type
|
||||
|
||||
import os
|
||||
import pty
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from ansible.module_utils._text import to_bytes
|
||||
from ansible.module_utils.six.moves import cPickle, StringIO
|
||||
from ansible.plugins.connection import ConnectionBase
|
||||
|
||||
try:
|
||||
from __main__ import display
|
||||
except ImportError:
|
||||
from ansible.utils.display import Display
|
||||
display = Display()
|
||||
|
||||
|
||||
class Connection(ConnectionBase):
|
||||
''' Local based connections '''
|
||||
|
||||
transport = 'persistent'
|
||||
has_pipelining = False
|
||||
|
||||
def _connect(self):
|
||||
|
||||
self._connected = True
|
||||
return self
|
||||
|
||||
def _do_it(self, action):
|
||||
|
||||
master, slave = pty.openpty()
|
||||
p = subprocess.Popen(["ansible-connection"], stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
stdin = os.fdopen(master, 'wb', 0)
|
||||
os.close(slave)
|
||||
|
||||
src = StringIO()
|
||||
cPickle.dump(self._play_context.serialize(), src)
|
||||
stdin.write(src.getvalue())
|
||||
src.close()
|
||||
|
||||
stdin.write(b'\n#END_INIT#\n')
|
||||
stdin.write(to_bytes(action))
|
||||
stdin.write(b'\n\n')
|
||||
stdin.close()
|
||||
(stdout, stderr) = p.communicate()
|
||||
|
||||
return (p.returncode, stdout, stderr)
|
||||
|
||||
def exec_command(self, cmd, in_data=None, sudoable=True):
|
||||
super(Connection, self).exec_command(cmd, in_data=in_data, sudoable=sudoable)
|
||||
return self._do_it('EXEC: ' + cmd)
|
||||
|
||||
def put_file(self, in_path, out_path):
|
||||
super(Connection, self).put_file(in_path, out_path)
|
||||
self._do_it('PUT: %s %s' % (in_path, out_path))
|
||||
|
||||
def fetch_file(self, in_path, out_path):
|
||||
super(Connection, self).fetch_file(in_path, out_path)
|
||||
self._do_it('FETCH: %s %s' % (in_path, out_path))
|
||||
|
||||
def close(self):
|
||||
self._connected = False
|
Loading…
Reference in a new issue