ansible/library/utilities/accelerate

507 lines
18 KiB
Text
Raw Normal View History

2013-08-11 07:41:18 +02:00
#!/usr/bin/python
# -*- coding: utf-8 -*-
2014-01-28 17:20:36 +01:00
# (c) 2013, James Cammarata <jcammarata@ansible.com>
2013-08-11 07:41:18 +02:00
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
DOCUMENTATION = '''
---
module: accelerate
short_description: Enable accelerated mode on remote node
2013-08-11 07:41:18 +02:00
description:
- This module launches an ephemeral I(accelerate) daemon on the remote node which
2013-08-11 07:41:18 +02:00
Ansible can use to communicate with nodes at high speed.
- The daemon listens on a configurable port for a configurable amount of time.
- Accelerated mode is AES encrypted
version_added: "1.3"
options:
port:
description:
- TCP port for the socket connection
required: false
default: 5099
aliases: []
timeout:
description:
- The number of seconds the socket will wait for data. If none is received when the timeout value is reached, the connection will be closed.
required: false
default: 300
aliases: []
2013-08-11 07:41:18 +02:00
minutes:
description:
- The I(accelerate) listener daemon is started on nodes and will stay around for
2013-08-11 07:41:18 +02:00
this number of minutes before turning itself off.
required: false
default: 30
ipv6:
description:
- The listener daemon on the remote host will bind to the IPv6 wildcard address (::)
if this parameter is set to true.
required: false
default: false
2013-08-11 07:41:18 +02:00
notes:
- See the advanced playbooks chapter for more about using accelerated mode.
requirements: [ "python-keyczar" ]
2013-08-11 07:41:18 +02:00
author: James Cammarata
'''
EXAMPLES = '''
# To use accelerate mode, simply add "accelerate: true" to your play. The initial
2013-08-11 07:41:18 +02:00
# key exchange and starting up of the daemon will occur over SSH, but all commands and
# subsequent actions will be conducted over the raw socket connection using AES encryption
- hosts: devservers
accelerate: true
2013-08-11 07:41:18 +02:00
tasks:
- command: /usr/bin/anything
'''
import base64
import getpass
import json
2013-08-11 07:41:18 +02:00
import os
import os.path
import pwd
import signal
2013-08-11 07:41:18 +02:00
import socket
import struct
import sys
2013-08-11 07:41:18 +02:00
import syslog
import tempfile
2013-08-11 07:41:18 +02:00
import time
import traceback
import SocketServer
from datetime import datetime
from threading import Thread
2013-08-11 07:41:18 +02:00
syslog.openlog('ansible-%s' % os.path.basename(__file__))
PIDFILE = os.path.expanduser("~/.accelerate.pid")
2013-08-11 07:41:18 +02:00
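# the chunk size to read and send. The arithmetic below was worked out for
# a 975-byte chunk, assuming mtu 1500 and leaving room for base64 (+33%)
# encoding and header (100 bytes): 4 * (975/3) + 100 = 1400, which leaves
# room for the TCP/IP header.
# NOTE: the value set below is 10240, which no longer matches that arithmetic.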
CHUNK_SIZE=10240
# FIXME: this all should be moved to module_common, as it's
# pretty much a copy from the callbacks/util code
DEBUG_LEVEL=0
def log(msg, cap=0):
    """Write msg to syslog when the module-level DEBUG_LEVEL is at least cap."""
    if cap > DEBUG_LEVEL:
        return
    priority = syslog.LOG_NOTICE | syslog.LOG_DAEMON
    syslog.syslog(priority, msg)
def vv(msg):
    """Log msg only when the debug level is 2 or higher."""
    log(msg, 2)
def vvv(msg):
    """Log msg only when the debug level is 3 or higher."""
    log(msg, 3)
def vvvv(msg):
    """Log msg only when the debug level is 4 or higher."""
    log(msg, 4)
2013-08-11 07:41:18 +02:00
# Kill any daemon recorded in a leftover pidfile, then remove the pidfile.
if os.path.exists(PIDFILE):
    try:
        # context manager so the handle is closed as soon as the pid is read
        with open(PIDFILE) as stale_fh:
            stale_pid = int(stale_fh.read())
        os.kill(stale_pid, signal.SIGKILL)
    except (ValueError, OSError, IOError):
        # ValueError: the pidfile did not contain an integer
        # OSError:    the kill failed (e.g. no such process)
        # IOError:    the pidfile could not be read after the exists() check
        pass
    os.unlink(PIDFILE)
# Record whether python-keyczar can be imported; main() reports the failure.
try:
    from keyczar.keys import AesKey
except ImportError:
    HAS_KEYCZAR = False
else:
    HAS_KEYCZAR = True
# NOTE: this shares a fair amount of code in common with async_wrapper, if async_wrapper were a new module we could move
# this into utils.module_common and probably should anyway
def daemonize_self(module, password, port, minutes):
    """Detach the current process into the background with a double fork.

    Three processes are involved:

    * the original process calls ``module.exit_json`` with a status message
      (the original comment describes this as exiting the first parent);
    * the intermediate child writes the daemon's pid to ``PIDFILE`` and exits;
    * the grandchild returns from this function with stdin, stdout and
      stderr redirected to the null device.

    :param module: object providing ``exit_json``.
    :param password: not used here; kept so existing callers keep working.
    :param port: only interpolated into the exit message.
    :param minutes: only interpolated into the exit message.
    """
    # daemonizing code: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66012
    try:
        pid = os.fork()
        if pid > 0:
            vvv("exiting pid %s" % pid)
            # exit first parent
            module.exit_json(msg="daemonized accelerate on port %s for %s minutes with pid %s" % (port, minutes, str(pid)))
    except OSError as e:
        log("fork #1 failed: %d (%s)" % (e.errno, e.strerror))
        sys.exit(1)

    # decouple from parent environment
    os.chdir("/")
    os.setsid()
    os.umask(0o22)

    # do second fork
    try:
        pid = os.fork()
        if pid > 0:
            log("daemon pid %s, writing %s" % (pid, PIDFILE))
            with open(PIDFILE, "w") as pid_file:
                pid_file.write("%s" % pid)
            vvv("pidfile written")
            sys.exit(0)
    except OSError as e:
        log("fork #2 failed: %d (%s)" % (e.errno, e.strerror))
        sys.exit(1)

    # 'rw' is not a valid open() mode, so the old file('/dev/null', 'rw')
    # did not give a descriptor usable for writing. Open the null device
    # explicitly read/write so stdout/stderr accept writes.
    null_fd = os.open(os.devnull, os.O_RDWR)
    std_fds = [stream.fileno() for stream in (sys.stdin, sys.stdout, sys.stderr)]
    for std_fd in std_fds:
        os.dup2(null_fd, std_fd)
    if null_fd not in std_fds:
        # the duplicates keep the device open; the extra descriptor is not needed
        os.close(null_fd)
    log("daemonizing successful")
class ThreadWithReturnValue(Thread):
    """Thread that remembers what its target returned.

    ``join()`` returns the target's return value. It returns ``None`` if the
    target has not finished (for example the join timed out) or if no target
    was given. The same value is stored on the ``_return`` attribute.
    """

    def __init__(self, group=None, target=None, name=None, args=(), kwargs=None, Verbose=None):
        # kwargs defaults to None rather than a shared mutable {}.
        # Verbose is forwarded by keyword, and only when the caller set it,
        # so the default path does not depend on Thread's positional order.
        extra = {}
        if Verbose is not None:
            extra['verbose'] = Verbose
        Thread.__init__(self, group=group, target=target, name=name,
                        args=args, kwargs=kwargs, **extra)
        # Keep our own references instead of reaching into Thread's
        # name-mangled private attributes from run().
        self._twrv_target = target
        self._twrv_args = args
        self._twrv_kwargs = kwargs or {}
        self._return = None

    def run(self):
        """Call the target, if any, and store its return value."""
        if self._twrv_target is not None:
            self._return = self._twrv_target(*self._twrv_args, **self._twrv_kwargs)

    def join(self, timeout=None):
        """Wait like ``Thread.join`` and return the stored result."""
        Thread.join(self, timeout=timeout)
        return self._return
2013-08-11 07:41:18 +02:00
class ThreadedTCPServer(SocketServer.ThreadingTCPServer):
    """Threading TCP server carrying the module, the AES key and a timeout.

    Request handlers reach these through ``self.server``.
    """

    def __init__(self, server_address, RequestHandlerClass, module, password, timeout):
        # Everything is assigned before the base constructor is invoked
        # (allow_reuse_address has to be in place before the socket is bound).
        self.allow_reuse_address = True
        self.timeout = timeout
        self.module = module
        self.key = AesKey.Read(password)
        SocketServer.ThreadingTCPServer.__init__(
            self, server_address, RequestHandlerClass
        )
class ThreadedTCPV6Server(SocketServer.ThreadingTCPServer):
    """IPv6 threading TCP server carrying the module, the AES key and a timeout.

    Request handlers reach these through ``self.server``.
    """

    def __init__(self, server_address, RequestHandlerClass, module, password, timeout):
        # Everything, including the address family, is assigned before the
        # base constructor is invoked so that the socket it creates uses it.
        self.address_family = socket.AF_INET6
        self.allow_reuse_address = True
        self.timeout = timeout
        self.module = module
        self.key = AesKey.Read(password)
        SocketServer.ThreadingTCPServer.__init__(
            self, server_address, RequestHandlerClass
        )
2013-08-11 07:41:18 +02:00
class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
    """Per-connection request handler of the accelerate daemon.

    Every message on the wire is an 8-byte length header (``struct`` format
    ``!Q``) followed by that many payload bytes. Payloads are JSON documents
    encrypted/decrypted with ``self.server.key``. Requests carry a ``mode``
    key: ``command``, ``put``, ``fetch`` or ``validate_user``.
    """

    def send_data(self, data):
        """Send data prefixed with its packed 8-byte length.

        Returns whatever ``sendall`` on the request socket returns.
        """
        packed_len = struct.pack('!Q', len(data))
        return self.request.sendall(packed_len + data)

    def _send_json(self, payload):
        """JSON-encode and encrypt payload, then send it with send_data()."""
        encoded = json.dumps(payload)
        encrypted = self.server.key.Encrypt(encoded)
        return self.send_data(encrypted)

    def _recv_exact(self, wanted, trace=False):
        """Read exactly `wanted` bytes from the request socket.

        Returns the bytes read, or None when recv() returns nothing before
        `wanted` bytes have arrived. With trace=True the running total is
        logged after each chunk.
        """
        chunks = []
        received = 0
        while received < wanted:
            d = self.request.recv(wanted - received)
            if not d:
                vvv("received nothing, bailing out")
                return None
            chunks.append(d)
            received += len(d)
            if trace:
                vvvv("data received so far (expecting %d): %d" % (wanted, received))
        # join once instead of growing a string chunk by chunk
        return b"".join(chunks)

    def recv_data(self):
        """Receive one length-prefixed message.

        Returns the payload (possibly empty), or None if the peer stopped
        sending before a complete message arrived.
        """
        header_len = 8  # size of a packed unsigned long long
        vvvv("in recv_data(), waiting for the header")
        header = self._recv_exact(header_len)
        if header is None:
            return None
        vvvv("in recv_data(), got the header, unpacking")
        data_len = struct.unpack('!Q', header)[0]
        vvvv("data received so far (expecting %d): %d" % (data_len, 0))
        data = self._recv_exact(data_len, trace=True)
        if data is None:
            return None
        vvvv("received all of the data, returning")
        return data

    def _run_with_keepalive(self, data):
        """Run self.command(data) in a worker thread and return its result.

        While the worker is alive, a ``pong`` message is sent whenever at
        least 15 seconds have passed since the previous one.
        """
        worker = ThreadWithReturnValue(target=self.command, args=(data,))
        last_pong = datetime.now()
        worker.start()
        while worker.is_alive():
            if (datetime.now() - last_pong).seconds >= 15:
                last_pong = datetime.now()
                vvvv("command still running, sending keepalive packet")
                self._send_json(dict(pong=True))
            time.sleep(0.1)
        return worker.join()

    def handle(self):
        """Serve requests on this connection until no more data arrives.

        Each request is decrypted, dispatched on its ``mode`` and answered
        with an encrypted JSON response. A decryption failure is answered
        with ``rc=1`` and ends the connection. A ``validate_user`` result of
        ``rc == 1`` additionally shuts the whole server down.
        """
        try:
            while True:
                vvvv("waiting for data")
                data = self.recv_data()
                if not data:
                    vvvv("received nothing back from recv_data(), breaking out")
                    break
                try:
                    vvvv("got data, decrypting")
                    data = self.server.key.Decrypt(data)
                    vvvv("decryption done")
                except Exception:
                    vv("bad decrypt, skipping...")
                    self._send_json(dict(rc=1))
                    return
                vvvv("loading json from the data")
                data = json.loads(data)

                mode = data['mode']
                response = {}
                if mode == 'command':
                    vvvv("received a command request, running it")
                    response = self._run_with_keepalive(data)
                    vvvv("thread is done, response from join was %s" % response)
                elif mode == 'put':
                    vvvv("received a put request, putting it")
                    response = self.put(data)
                elif mode == 'fetch':
                    vvvv("received a fetch request, getting it")
                    response = self.fetch(data)
                elif mode == 'validate_user':
                    vvvv("received a request to validate the user id")
                    response = self.validate_user(data)

                vvvv("response result is %s" % str(response))
                vvvv("sending the response back to the controller")
                self._send_json(response)
                vvvv("done sending the response")

                if mode == 'validate_user' and response.get('rc') == 1:
                    vvvv("detected a uid mismatch, shutting down")
                    self.server.shutdown()
        except:
            # NOTE(review): kept as a bare except on purpose. Helpers on
            # self.server.module may raise SystemExit (presumably through
            # fail_json -- confirm), and the controller should still get a
            # reply in that case.
            tb = traceback.format_exc()
            log("encountered an unhandled exception in the handle() function")
            log("error was:\n%s" % tb)
            try:
                self._send_json(dict(rc=1, failed=True, msg="unhandled error in the handle() function"))
            except Exception:
                # the socket itself may be what failed; nothing more to do
                log("could not report the failure back to the controller")

    def validate_user(self, data):
        """Check that the daemon runs as the user named in data['username'].

        Returns ``dict(rc=0)`` when the uids match and ``dict(rc=1)`` when
        they differ. A missing key or an unknown user gives a ``failed``
        dict without ``rc``.
        """
        if 'username' not in data:
            return dict(failed=True, msg='No username specified')

        vvvv("validating we're running as %s" % data['username'])

        # get the current uid
        c_uid = os.getuid()
        try:
            # the target uid
            t_uid = pwd.getpwnam(data['username']).pw_uid
        except Exception:
            vvvv("could not find user %s" % data['username'])
            return dict(failed=True, msg='could not find user %s' % data['username'])

        # and return rc=0 for success, rc=1 for failure
        if c_uid == t_uid:
            return dict(rc=0)
        return dict(rc=1)

    def command(self, data):
        """Run data['cmd'] through the module's run_command().

        ``tmp_path`` must be present in the request although it is not used
        here. When ``executable`` is given, the command is run with
        ``use_unsafe_shell=True``. Returns a dict with ``rc``, ``stdout``
        and ``stderr``; None outputs are replaced by empty strings.
        """
        if 'cmd' not in data:
            return dict(failed=True, msg='internal error: cmd is required')
        if 'tmp_path' not in data:
            return dict(failed=True, msg='internal error: tmp_path is required')

        vvvv("executing: %s" % data['cmd'])

        executable = data.get('executable')
        use_unsafe_shell = bool(executable)

        rc, stdout, stderr = self.server.module.run_command(data['cmd'], executable=executable, use_unsafe_shell=use_unsafe_shell)
        if stdout is None:
            stdout = ''
        if stderr is None:
            stderr = ''
        vvvv("got stdout: %s" % stdout)
        vvvv("got stderr: %s" % stderr)

        return dict(rc=rc, stdout=stdout, stderr=stderr)

    def fetch(self, data):
        """Stream the file at data['in_path'] to the peer in chunks.

        Each chunk of at most CHUNK_SIZE bytes is sent base64-encoded with
        a ``last`` flag, and an acknowledgement is read back before the next
        chunk goes out. Returns an empty dict on success, or a dict with
        ``failed=True`` on any error.

        NOTE(review): for a zero-byte file no chunk is sent at all; confirm
        the controller copes with receiving only the final response.
        """
        if 'in_path' not in data:
            return dict(failed=True, msg='internal error: in_path is required')

        fd = None
        try:
            fd = open(data['in_path'], 'rb')
            fstat = os.stat(data['in_path'])
            vvv("FETCH file is %d bytes" % fstat.st_size)
            while fd.tell() < fstat.st_size:
                chunk = fd.read(CHUNK_SIZE)
                # an empty read means the file shrank under us: treat it as
                # the end instead of sending empty chunks forever
                last = (not chunk) or fd.tell() >= fstat.st_size
                if self._send_json(dict(data=base64.b64encode(chunk), last=last)):
                    return dict(failed=True, stderr="failed to send data")

                response = self.recv_data()
                if not response:
                    log("failed to get a response, aborting")
                    # the handler has no 'host' attribute; report the peer address
                    return dict(failed=True, stderr="Failed to get a response from %s" % self.client_address[0])
                response = self.server.key.Decrypt(response)
                response = json.loads(response)

                if response.get('failed', False):
                    log("got a failed response from the master")
                    return dict(failed=True, stderr="Master reported failure, aborting transfer")
                if last:
                    break
        except Exception as e:
            tb = traceback.format_exc()
            log("failed to fetch the file: %s" % tb)
            return dict(failed=True, stderr="Could not fetch the file: %s" % str(e))
        finally:
            # fd is still None when open() itself failed
            if fd is not None:
                fd.close()

        return dict()

    def put(self, data):
        """Write the chunks sent by the peer to data['out_path'].

        The first chunk arrives in `data`; every chunk is acknowledged with
        an empty response, and further chunks are read until one has a true
        ``last`` value. When ``user`` is given and differs from the current
        user, the content goes to a temporary file under ~/.ansible/tmp/
        and is moved into place with the module's atomic_move().

        Returns an empty dict on success, or a dict with ``failed=True``.
        """
        if 'data' not in data:
            return dict(failed=True, msg='internal error: data is required')
        if 'out_path' not in data:
            return dict(failed=True, msg='internal error: out_path is required')

        final_path = None
        if 'user' in data and data.get('user') != getpass.getuser():
            vv("the target user doesn't match this user, we'll move the file into place via sudo")
            tmp_path = os.path.expanduser('~/.ansible/tmp/')
            if not os.path.exists(tmp_path):
                try:
                    os.makedirs(tmp_path, 0o700)
                except Exception:
                    return dict(failed=True, msg='could not create a temporary directory at %s' % tmp_path)
            (fd, out_path) = tempfile.mkstemp(prefix='ansible.', dir=tmp_path)
            # base64-decoded content is raw bytes, so write in binary mode
            out_fd = os.fdopen(fd, 'wb', 0)
            final_path = data['out_path']
        else:
            out_path = data['out_path']
            out_fd = open(out_path, 'wb')

        written = 0
        try:
            while True:
                out = base64.b64decode(data['data'])
                written += len(out)
                out_fd.write(out)
                # acknowledge the chunk with an empty response
                self._send_json(dict())
                if data['last']:
                    break
                data = self.recv_data()
                if not data:
                    # a real exception, so the logged traceback names the cause
                    raise IOError("connection closed before the last chunk arrived")
                data = self.server.key.Decrypt(data)
                data = json.loads(data)
        except Exception:
            tb = traceback.format_exc()
            log("failed to put the file: %s" % tb)
            return dict(failed=True, stdout="Could not write the file")
        finally:
            out_fd.close()

        vvvv("wrote %d bytes" % written)
        if final_path:
            vvv("moving %s to %s" % (out_path, final_path))
            self.server.module.atomic_move(out_path, final_path)
        return dict()
def daemonize(module, password, port, timeout, minutes, ipv6):
    """Daemonize, then serve accelerate requests until the timer expires.

    After ``daemonize_self`` returns in the daemon process, a real-time
    interval timer of ``60 * minutes`` seconds is armed; its SIGALRM handler
    calls ``module.exit_json``. The TCP server is then created (up to five
    attempts, 0.2 seconds apart) and served with a 0.1 second poll interval.
    Any exception is logged and the process exits with status 0.

    :param module: object handed to the server and used for exit_json.
    :param password: key material handed to the server.
    :param port: TCP port to listen on.
    :param timeout: timeout value handed to the server.
    :param minutes: lifetime of the daemon, in minutes.
    :param ipv6: listen on "::" with the IPv6 server when true, else on
        "0.0.0.0".
    """
    try:
        daemonize_self(module, password, port, minutes)

        def catcher(signum, _):
            module.exit_json(msg='timer expired')

        signal.signal(signal.SIGALRM, catcher)
        signal.setitimer(signal.ITIMER_REAL, 60 * minutes)

        server = None
        tries = 5
        while tries > 0:
            try:
                if ipv6:
                    server = ThreadedTCPV6Server(("::", port), ThreadedTCPRequestHandler, module, password, timeout)
                else:
                    server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password, timeout)
                break
            except Exception as e:
                # count down first so the logged number is what is really
                # left, and keep the reason instead of discarding it
                tries -= 1
                vv("Failed to create the TCP server (tries left = %d): %s" % (tries, e))
                time.sleep(0.2)

        if server is None:
            vv("Maximum number of attempts to create the TCP server reached, bailing out")
            raise Exception("max # of attempts to serve reached")

        vv("serving!")
        server.serve_forever(poll_interval=0.1)
    except Exception as e:
        tb = traceback.format_exc()
        log("exception caught, exiting accelerated mode: %s\n%s" % (e, tb))
        sys.exit(0)
def main():
    """Read the module arguments and start the accelerate daemon.

    Fails through ``module.fail_json`` when keyczar could not be imported.
    Sets the module-level DEBUG_LEVEL from the ``debug`` argument before
    daemonizing.
    """
    global DEBUG_LEVEL

    spec = {
        'port': dict(required=False, default=5099),
        'ipv6': dict(required=False, default=False, type='bool'),
        'timeout': dict(required=False, default=300),
        'password': dict(required=True),
        'minutes': dict(required=False, default=30),
        'debug': dict(required=False, default=0, type='int'),
    }
    module = AnsibleModule(argument_spec=spec, supports_check_mode=True)
    params = module.params

    # the password arrives base64-encoded
    password = base64.b64decode(params['password'])
    port = int(params['port'])
    timeout = int(params['timeout'])
    minutes = int(params['minutes'])
    debug = int(params['debug'])
    ipv6 = params['ipv6']

    if not HAS_KEYCZAR:
        module.fail_json(msg="keyczar is not installed (on the remote side)")

    DEBUG_LEVEL = debug

    daemonize(module, password, port, timeout, minutes, ipv6)
2013-08-11 07:41:18 +02:00
2013-12-02 21:13:49 +01:00
# import module snippets
from ansible.module_utils.basic import *
2013-08-11 07:41:18 +02:00
main()