Merge branch 'accelerate_improvements' into devel

Conflicts:
	library/utilities/accelerate
This commit is contained in:
James Cammarata 2013-10-01 21:22:17 -05:00
commit 912e3a7b0b
4 changed files with 166 additions and 69 deletions

View file

@ -136,6 +136,8 @@ ANSIBLE_SSH_CONTROL_PATH = get_config(p, 'ssh_connection', 'control_path',
PARAMIKO_RECORD_HOST_KEYS = get_config(p, 'paramiko_connection', 'record_host_keys', 'ANSIBLE_PARAMIKO_RECORD_HOST_KEYS', True, boolean=True) PARAMIKO_RECORD_HOST_KEYS = get_config(p, 'paramiko_connection', 'record_host_keys', 'ANSIBLE_PARAMIKO_RECORD_HOST_KEYS', True, boolean=True)
ZEROMQ_PORT = get_config(p, 'fireball_connection', 'zeromq_port', 'ANSIBLE_ZEROMQ_PORT', 5099, integer=True) ZEROMQ_PORT = get_config(p, 'fireball_connection', 'zeromq_port', 'ANSIBLE_ZEROMQ_PORT', 5099, integer=True)
ACCELERATE_PORT = get_config(p, 'accelerate', 'accelerate_port', 'ACCELERATE_PORT', 5099, integer=True) ACCELERATE_PORT = get_config(p, 'accelerate', 'accelerate_port', 'ACCELERATE_PORT', 5099, integer=True)
ACCELERATE_TIMEOUT = int(get_config(p, 'accelerate', 'accelerate_timeout', 'ACCELERATE_TIMEOUT', 30))
ACCELERATE_CONNECT_TIMEOUT = float(get_config(p, 'accelerate', 'accelerate_connect_timeout', 'ACCELERATE_CONNECT_TIMEOUT', 1.0))
DEFAULT_UNDEFINED_VAR_BEHAVIOR = get_config(p, DEFAULTS, 'error_on_undefined_vars', 'ANSIBLE_ERROR_ON_UNDEFINED_VARS', True, boolean=True) DEFAULT_UNDEFINED_VAR_BEHAVIOR = get_config(p, DEFAULTS, 'error_on_undefined_vars', 'ANSIBLE_ERROR_ON_UNDEFINED_VARS', True, boolean=True)
HOST_KEY_CHECKING = get_config(p, DEFAULTS, 'host_key_checking', 'ANSIBLE_HOST_KEY_CHECKING', True, boolean=True) HOST_KEY_CHECKING = get_config(p, DEFAULTS, 'host_key_checking', 'ANSIBLE_HOST_KEY_CHECKING', True, boolean=True)

View file

@ -454,7 +454,7 @@ class PlayBook(object):
setup_cache=self.SETUP_CACHE, callbacks=self.runner_callbacks, sudo=play.sudo, sudo_user=play.sudo_user, setup_cache=self.SETUP_CACHE, callbacks=self.runner_callbacks, sudo=play.sudo, sudo_user=play.sudo_user,
transport=play.transport, sudo_pass=self.sudo_pass, is_playbook=True, module_vars=play.vars, transport=play.transport, sudo_pass=self.sudo_pass, is_playbook=True, module_vars=play.vars,
default_vars=play.default_vars, check=self.check, diff=self.diff, default_vars=play.default_vars, check=self.check, diff=self.diff,
accelerate=play.accelerate, accelerate_port=play.accelerate_port accelerate=play.accelerate, accelerate_port=play.accelerate_port,
).run() ).run()
self.stats.compute(setup_results, setup=True) self.stats.compute(setup_results, setup=True)

View file

@ -21,7 +21,7 @@ import base64
import socket import socket
import struct import struct
import time import time
from ansible.callbacks import vvv from ansible.callbacks import vvv, vvvv
from ansible.runner.connection_plugins.ssh import Connection as SSHConnection from ansible.runner.connection_plugins.ssh import Connection as SSHConnection
from ansible.runner.connection_plugins.paramiko_ssh import Connection as ParamikoConnection from ansible.runner.connection_plugins.paramiko_ssh import Connection as ParamikoConnection
from ansible import utils from ansible import utils
@ -84,12 +84,13 @@ class Connection(object):
utils.AES_KEYS = self.runner.aes_keys utils.AES_KEYS = self.runner.aes_keys
def _execute_accelerate_module(self):
    """Run the ``accelerate`` module on the remote host over SSH.

    The module arguments carry the base64-encoded key, the accelerate
    port and the current verbosity (as ``debug``).  The key is also
    injected as ``password`` together with the inventory variables of
    the host the daemon is started for.

    Returns whatever the runner's ``_execute_module`` returns.
    """
    encoded_key = base64.b64encode(self.key.__str__())
    args = "password=%s port=%s debug=%d" % (encoded_key, str(self.accport), int(utils.VERBOSITY))

    # Prefer the dedicated accelerate inventory host when one is
    # configured; otherwise use the host of this connection.
    if self.runner.accelerate_inventory_host:
        vars_host = self.runner.accelerate_inventory_host
    else:
        vars_host = self.host
    host_vars = self.runner.inventory.get_variables(vars_host)
    inject = utils.combine_vars(dict(password=self.key), host_vars)

    vvvv("attempting to start up the accelerate daemon...")
    self.ssh.connect()
    tmp_path = self.runner._make_tmp_path(self.ssh)
    return self.runner._execute_module(self.ssh, tmp_path, 'accelerate', args, inject=inject)
@ -99,20 +100,22 @@ class Connection(object):
try: try:
if not self.is_connected: if not self.is_connected:
# TODO: make the timeout and retries configurable?
tries = 3 tries = 3
self.conn = socket.socket() self.conn = socket.socket()
self.conn.settimeout(300.0) self.conn.settimeout(constants.ACCELERATE_CONNECT_TIMEOUT)
vvvv("attempting connection to %s via the accelerated port %d" % (self.host,self.accport))
while tries > 0: while tries > 0:
try: try:
self.conn.connect((self.host,self.accport)) self.conn.connect((self.host,self.accport))
break break
except: except:
vvvv("failed, retrying...")
time.sleep(0.1) time.sleep(0.1)
tries -= 1 tries -= 1
if tries == 0: if tries == 0:
vvv("Could not connect via the accelerated connection, exceeded # of tries") vvv("Could not connect via the accelerated connection, exceeded # of tries")
raise errors.AnsibleError("Failed to connect") raise errors.AnsibleError("Failed to connect")
self.conn.settimeout(constants.ACCELERATE_TIMEOUT)
except: except:
if allow_ssh: if allow_ssh:
vvv("Falling back to ssh to startup accelerated mode") vvv("Falling back to ssh to startup accelerated mode")
@ -133,18 +136,24 @@ class Connection(object):
header_len = 8 # size of a packed unsigned long long header_len = 8 # size of a packed unsigned long long
data = b"" data = b""
try: try:
vvvv("%s: in recv_data(), waiting for the header" % self.host)
while len(data) < header_len: while len(data) < header_len:
d = self.conn.recv(1024) d = self.conn.recv(header_len - len(data))
if not d: if not d:
vvvv("%s: received nothing, bailing out" % self.host)
return None return None
data += d data += d
vvvv("%s: got the header, unpacking" % self.host)
data_len = struct.unpack('Q',data[:header_len])[0] data_len = struct.unpack('Q',data[:header_len])[0]
data = data[header_len:] data = data[header_len:]
vvvv("%s: data received so far (expecting %d): %d" % (self.host,data_len,len(data)))
while len(data) < data_len: while len(data) < data_len:
d = self.conn.recv(1024) d = self.conn.recv(data_len - len(data))
if not d: if not d:
vvvv("%s: received nothing, bailing out" % self.host)
return None return None
data += d data += d
vvvv("%s: received all of the data, returning" % self.host)
return data return data
except socket.timeout: except socket.timeout:
raise errors.AnsibleError("timed out while waiting to receive data") raise errors.AnsibleError("timed out while waiting to receive data")
@ -171,11 +180,22 @@ class Connection(object):
if self.send_data(data): if self.send_data(data):
raise errors.AnsibleError("Failed to send command to %s" % self.host) raise errors.AnsibleError("Failed to send command to %s" % self.host)
response = self.recv_data() while True:
if not response: # we loop here while waiting for the response, because a
raise errors.AnsibleError("Failed to get a response from %s" % self.host) # long running command may cause us to receive keepalive packets
response = utils.decrypt(self.key, response) # ({"pong":"true"}) rather than the response we want.
response = utils.parse_json(response) response = self.recv_data()
if not response:
raise errors.AnsibleError("Failed to get a response from %s" % self.host)
response = utils.decrypt(self.key, response)
response = utils.parse_json(response)
if "pong" in response:
# it's a keepalive, go back to waiting
vvvv("%s: received a keepalive packet" % self.host)
continue
else:
vvvv("%s: received the response" % self.host)
break
return (response.get('rc',None), '', response.get('stdout',''), response.get('stderr','')) return (response.get('rc',None), '', response.get('stdout',''), response.get('stderr',''))

View file

@ -35,6 +35,12 @@ options:
required: false required: false
default: 5099 default: 5099
aliases: [] aliases: []
timeout:
description:
- The number of seconds the socket will wait for data. If none is received when the timeout value is reached, the connection will be closed.
required: false
default: 300
aliases: []
minutes: minutes:
description: description:
- The I(accelerate) listener daemon is started on nodes and will stay around for - The I(accelerate) listener daemon is started on nodes and will stay around for
@ -58,24 +64,25 @@ EXAMPLES = '''
- command: /usr/bin/anything - command: /usr/bin/anything
''' '''
import os
import os.path
import tempfile
import sys
import shutil
import socket
import struct
import time
import base64 import base64
import getpass import getpass
import os
import os.path
import shutil
import signal
import socket
import struct
import sys
import syslog import syslog
import signal import tempfile
import time import time
import signal
import traceback import traceback
import SocketServer import SocketServer
from datetime import datetime
from threading import Thread
syslog.openlog('ansible-%s' % os.path.basename(__file__)) syslog.openlog('ansible-%s' % os.path.basename(__file__))
PIDFILE = os.path.expanduser("~/.accelerate.pid") PIDFILE = os.path.expanduser("~/.accelerate.pid")
@ -85,8 +92,22 @@ PIDFILE = os.path.expanduser("~/.accelerate.pid")
# which leaves room for the TCP/IP header # which leaves room for the TCP/IP header
CHUNK_SIZE=10240 CHUNK_SIZE=10240
def log(msg): # FIXME: this all should be moved to module_common, as it's
syslog.syslog(syslog.LOG_NOTICE|syslog.LOG_DAEMON, msg) # pretty much a copy from the callbacks/util code
# Verbosity level for the daemon's syslog output.  main() overwrites this
# with the module's ``debug`` argument before daemonizing.
DEBUG_LEVEL = 0


def log(msg, cap=0):
    """Send *msg* to syslog if the current verbosity allows it.

    ``cap`` is the minimum DEBUG_LEVEL at which the message is emitted.
    The default of 0 means the message is always logged.
    """
    # Emit only when the configured verbosity has reached the message's
    # level.  The comparison was previously reversed (cap >= DEBUG_LEVEL),
    # which logged every debug message at the default level and silenced
    # plain log() calls once debugging was turned up.
    if DEBUG_LEVEL >= cap:
        syslog.syslog(syslog.LOG_NOTICE | syslog.LOG_DAEMON, msg)


def vv(msg):
    """Log *msg* when DEBUG_LEVEL is at least 2."""
    log(msg, cap=2)


def vvv(msg):
    """Log *msg* when DEBUG_LEVEL is at least 3."""
    log(msg, cap=3)


def vvvv(msg):
    """Log *msg* when DEBUG_LEVEL is at least 4."""
    log(msg, cap=4)
if os.path.exists(PIDFILE): if os.path.exists(PIDFILE):
try: try:
@ -114,7 +135,7 @@ def daemonize_self(module, password, port, minutes):
try: try:
pid = os.fork() pid = os.fork()
if pid > 0: if pid > 0:
log("exiting pid %s" % pid) vvv("exiting pid %s" % pid)
# exit first parent # exit first parent
module.exit_json(msg="daemonized accelerate on port %s for %s minutes" % (port, minutes)) module.exit_json(msg="daemonized accelerate on port %s for %s minutes" % (port, minutes))
except OSError, e: except OSError, e:
@ -134,7 +155,7 @@ def daemonize_self(module, password, port, minutes):
pid_file = open(PIDFILE, "w") pid_file = open(PIDFILE, "w")
pid_file.write("%s" % pid) pid_file.write("%s" % pid)
pid_file.close() pid_file.close()
log("pidfile written") vvv("pidfile written")
sys.exit(0) sys.exit(0)
except OSError, e: except OSError, e:
log("fork #2 failed: %d (%s)" % (e.errno, e.strerror)) log("fork #2 failed: %d (%s)" % (e.errno, e.strerror))
@ -146,12 +167,25 @@ def daemonize_self(module, password, port, minutes):
os.dup2(dev_null.fileno(), sys.stderr.fileno()) os.dup2(dev_null.fileno(), sys.stderr.fileno())
log("daemonizing successful") log("daemonizing successful")
class ThreadWithReturnValue(Thread):
    """Thread that keeps the return value of its target callable.

    The value is stored in ``_return`` when the target returns and is
    also handed back by :meth:`join`.  It is ``None`` when no target was
    given or the target has not returned a value.
    """

    def __init__(self, group=None, target=None, name=None, args=(), kwargs=None, Verbose=None):
        # ``kwargs`` defaults to None instead of a dict literal so that
        # instances never share one mutable default object.
        if kwargs is None:
            kwargs = {}
        init_kwargs = dict(group=group, name=name)
        if Verbose is not None:
            # Only forwarded when explicitly requested by the caller.
            init_kwargs['verbose'] = Verbose
        Thread.__init__(self, **init_kwargs)
        # Keep our own references to the callable and its arguments
        # rather than reaching into Thread's name-mangled private
        # attributes, which are an implementation detail.
        self._call_target = target
        self._call_args = args
        self._call_kwargs = kwargs
        self._return = None

    def run(self):
        """Invoke the target (if any) and remember what it returned."""
        if self._call_target is not None:
            self._return = self._call_target(*self._call_args, **self._call_kwargs)

    def join(self, timeout=None):
        """Wait for the thread like ``Thread.join`` and return the stored value."""
        Thread.join(self, timeout=timeout)
        return self._return
class ThreadedTCPServer(SocketServer.ThreadingTCPServer):
    """Threading TCP server that carries the state request handlers need.

    Stores the module object, the AES key read from ``password`` and the
    ``timeout`` value given by the caller as instance attributes.
    """

    def __init__(self, server_address, RequestHandlerClass, module, password, timeout):
        # All attributes, including allow_reuse_address, are assigned
        # before the base class initializer is invoked.
        self.timeout = timeout
        self.allow_reuse_address = True
        self.module = module
        self.key = AesKey.Read(password)
        SocketServer.ThreadingTCPServer.__init__(self, server_address, RequestHandlerClass)
class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
@ -162,50 +196,85 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
def recv_data(self):
    """Read one length-prefixed message from the request socket.

    A message is an 8-byte header holding the payload length, followed
    by the payload itself.  Returns the payload bytes, or ``None`` when
    the peer stops sending before a full message has arrived.
    """
    header_len = 8 # size of a packed unsigned long long
    sock = self.request
    received = b""

    vvvv("in recv_data(), waiting for the header")
    remaining = header_len
    while remaining > 0:
        # Never ask for more than the header still needs.
        chunk = sock.recv(remaining)
        if not chunk:
            vvv("received nothing, bailing out")
            return None
        received += chunk
        remaining = header_len - len(received)

    vvvv("in recv_data(), got the header, unpacking")
    # 'Q' without a byte-order prefix: unsigned long long, native order.
    (data_len,) = struct.unpack('Q', received[:header_len])
    payload = received[header_len:]
    vvvv("data received so far (expecting %d): %d" % (data_len, len(payload)))

    while len(payload) < data_len:
        chunk = sock.recv(data_len - len(payload))
        if not chunk:
            vvv("received nothing, bailing out")
            return None
        payload += chunk
        vvvv("data received so far (expecting %d): %d" % (data_len, len(payload)))

    vvvv("received all of the data, returning")
    return payload
def handle(self):
    """Serve requests on this connection until the peer stops sending.

    Each message is decrypted with the server key, parsed as JSON and
    dispatched on its ``mode`` ('command', 'put' or 'fetch').  The result
    is JSON-encoded, encrypted and sent back.  Commands run in a helper
    thread so that a ``{"pong": true}`` keepalive packet can be sent
    every 15 seconds while the command is still running.

    A message that cannot be decrypted is answered with ``rc=1`` and ends
    the handler.  Any other exception is logged with its traceback and
    answered with a generic failure response.
    """
    try:
        while True:
            vvvv("waiting for data")
            data = self.recv_data()
            if not data:
                vvvv("received nothing back from recv_data(), breaking out")
                break
            try:
                vvvv("got data, decrypting")
                data = self.server.key.Decrypt(data)
                vvvv("decryption done")
            except Exception:
                vv("bad decrypt, skipping...")
                data2 = json.dumps(dict(rc=1))
                data2 = self.server.key.Encrypt(data2)
                # This used to call send_data(client, data2); ``client``
                # is not defined in this method, so the call raised
                # NameError instead of sending the rc=1 reply.
                self.send_data(data2)
                return

            vvvv("loading json from the data")
            data = json.loads(data)

            mode = data['mode']
            response = {}
            last_pong = datetime.now()
            if mode == 'command':
                vvvv("received a command request, running it")
                twrv = ThreadWithReturnValue(target=self.command, args=(data,))
                twrv.start()
                response = None
                # Poll the worker thread and keep the controller's
                # socket alive while the command is still running.
                while twrv.is_alive():
                    if (datetime.now() - last_pong).seconds >= 15:
                        last_pong = datetime.now()
                        vvvv("command still running, sending keepalive packet")
                        data2 = json.dumps(dict(pong=True))
                        data2 = self.server.key.Encrypt(data2)
                        self.send_data(data2)
                    time.sleep(0.1)
                # The thread has finished, so join() returns at once
                # with the value the command produced.
                response = twrv.join()
                vvvv("thread is done, response from join was %s" % response)
            elif mode == 'put':
                vvvv("received a put request, putting it")
                response = self.put(data)
            elif mode == 'fetch':
                vvvv("received a fetch request, getting it")
                response = self.fetch(data)

            vvvv("response result is %s" % str(response))
            data2 = json.dumps(response)
            data2 = self.server.key.Encrypt(data2)
            vvvv("sending the response back to the controller")
            self.send_data(data2)
            vvvv("done sending the response")
    except Exception:
        # Catch Exception rather than everything, so SystemExit and
        # KeyboardInterrupt are not swallowed by the handler.
        tb = traceback.format_exc()
        log("encountered an unhandled exception in the handle() function")
        log("error was:\n%s" % tb)
        data2 = json.dumps(dict(rc=1, failed=True, msg="unhandled error in the handle() function"))
        data2 = self.server.key.Encrypt(data2)
        self.send_data(data2)
@ -217,14 +286,14 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
if 'executable' not in data: if 'executable' not in data:
return dict(failed=True, msg='internal error: executable is required') return dict(failed=True, msg='internal error: executable is required')
#log("executing: %s" % data['cmd']) vvvv("executing: %s" % data['cmd'])
rc, stdout, stderr = self.server.module.run_command(data['cmd'], executable=data['executable'], close_fds=True) rc, stdout, stderr = self.server.module.run_command(data['cmd'], executable=data['executable'], close_fds=True)
if stdout is None: if stdout is None:
stdout = '' stdout = ''
if stderr is None: if stderr is None:
stderr = '' stderr = ''
#log("got stdout: %s" % stdout) vvvv("got stdout: %s" % stdout)
#log("got stderr: %s" % stderr) vvvv("got stderr: %s" % stderr)
return dict(rc=rc, stdout=stdout, stderr=stderr) return dict(rc=rc, stdout=stdout, stderr=stderr)
@ -235,7 +304,7 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
try: try:
fd = file(data['in_path'], 'rb') fd = file(data['in_path'], 'rb')
fstat = os.stat(data['in_path']) fstat = os.stat(data['in_path'])
log("FETCH file is %d bytes" % fstat.st_size) vvv("FETCH file is %d bytes" % fstat.st_size)
while fd.tell() < fstat.st_size: while fd.tell() < fstat.st_size:
data = fd.read(CHUNK_SIZE) data = fd.read(CHUNK_SIZE)
last = False last = False
@ -276,7 +345,7 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
final_path = None final_path = None
final_user = None final_user = None
if 'user' in data and data.get('user') != getpass.getuser(): if 'user' in data and data.get('user') != getpass.getuser():
log("the target user doesn't match this user, we'll move the file into place via sudo") vv("the target user doesn't match this user, we'll move the file into place via sudo")
(fd,out_path) = tempfile.mkstemp(prefix='ansible.', dir=os.path.expanduser('~/.ansible/tmp/')) (fd,out_path) = tempfile.mkstemp(prefix='ansible.', dir=os.path.expanduser('~/.ansible/tmp/'))
out_fd = os.fdopen(fd, 'w', 0) out_fd = os.fdopen(fd, 'w', 0)
final_path = data['out_path'] final_path = data['out_path']
@ -306,15 +375,15 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
log("failed to put the file: %s" % tb) log("failed to put the file: %s" % tb)
return dict(failed=True, stdout="Could not write the file") return dict(failed=True, stdout="Could not write the file")
finally: finally:
#log("wrote %d bytes" % bytes) vvvv("wrote %d bytes" % bytes)
out_fd.close() out_fd.close()
if final_path: if final_path:
log("moving %s to %s" % (out_path, final_path)) vvv("moving %s to %s" % (out_path, final_path))
self.server.module.atomic_move(out_path, final_path) self.server.module.atomic_move(out_path, final_path)
return dict() return dict()
def daemonize(module, password, port, minutes): def daemonize(module, password, port, timeout, minutes):
try: try:
daemonize_self(module, password, port, minutes) daemonize_self(module, password, port, minutes)
@ -324,10 +393,10 @@ def daemonize(module, password, port, minutes):
signal.signal(signal.SIGALRM, catcher) signal.signal(signal.SIGALRM, catcher)
signal.setitimer(signal.ITIMER_REAL, 60 * minutes) signal.setitimer(signal.ITIMER_REAL, 60 * minutes)
server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password) server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password, timeout)
server.allow_reuse_address = True server.allow_reuse_address = True
log("serving!") vv("serving!")
server.serve_forever(poll_interval=1.0) server.serve_forever(poll_interval=1.0)
except Exception, e: except Exception, e:
tb = traceback.format_exc() tb = traceback.format_exc()
@ -335,24 +404,30 @@ def daemonize(module, password, port, minutes):
sys.exit(0) sys.exit(0)
def main():
    """Module entry point: read the arguments and start the daemon.

    Decodes the base64 ``password``, converts the numeric options to
    integers, fails the module when keyczar is unavailable, stores the
    requested debug level in DEBUG_LEVEL and then calls daemonize().
    """
    global DEBUG_LEVEL

    spec = dict(
        port=dict(required=False, default=5099),
        timeout=dict(required=False, default=300),
        password=dict(required=True),
        minutes=dict(required=False, default=30),
        debug=dict(required=False, default=0, type='int')
    )
    module = AnsibleModule(argument_spec=spec, supports_check_mode=True)

    params = module.params
    password = base64.b64decode(params['password'])
    port = int(params['port'])
    timeout = int(params['timeout'])
    minutes = int(params['minutes'])
    debug = int(params['debug'])

    if not HAS_KEYCZAR:
        module.fail_json(msg="keyczar is not installed")

    # Set the verbosity before daemonizing so the daemon's log calls see it.
    DEBUG_LEVEL = debug

    daemonize(module, password, port, timeout, minutes)
# this is magic, see lib/ansible/module_common.py # this is magic, see lib/ansible/module_common.py
#<<INCLUDE_ANSIBLE_MODULE_COMMON>> #<<INCLUDE_ANSIBLE_MODULE_COMMON>>