Fireball2 mode working!

This commit is contained in:
James Cammarata 2013-08-11 00:41:18 -05:00
parent acc5d09351
commit 521e14a3ad
6 changed files with 399 additions and 155 deletions

View file

@ -312,7 +312,7 @@ class PlayBook(object):
conditional=task.only_if, callbacks=self.runner_callbacks,
sudo=task.sudo, sudo_user=task.sudo_user,
transport=task.transport, sudo_pass=task.sudo_pass, is_playbook=True,
check=self.check, diff=self.diff, environment=task.environment, complex_args=task.args,
check=self.check, diff=self.diff, environment=task.environment, complex_args=task.args, accelerate=task.play.accelerate,
error_on_undefined_vars=C.DEFAULT_UNDEFINED_VAR_BEHAVIOR
)

View file

@ -29,7 +29,7 @@ class Play(object):
__slots__ = [
'hosts', 'name', 'vars', 'vars_prompt', 'vars_files',
'handlers', 'remote_user', 'remote_port',
'handlers', 'remote_user', 'remote_port', 'accelerate',
'sudo', 'sudo_user', 'transport', 'playbook',
'tags', 'gather_facts', 'serial', '_ds', '_handlers', '_tasks',
'basedir', 'any_errors_fatal', 'roles', 'max_fail_pct'
@ -39,7 +39,7 @@ class Play(object):
# and don't line up 1:1 with how they are stored
VALID_KEYS = [
'hosts', 'name', 'vars', 'vars_prompt', 'vars_files',
'tasks', 'handlers', 'user', 'port', 'include',
'tasks', 'handlers', 'user', 'port', 'include', 'accelerate',
'sudo', 'sudo_user', 'connection', 'tags', 'gather_facts', 'serial',
'any_errors_fatal', 'roles', 'pre_tasks', 'post_tasks', 'max_fail_percentage'
]
@ -103,6 +103,7 @@ class Play(object):
self.gather_facts = ds.get('gather_facts', None)
self.remote_port = self.remote_port
self.any_errors_fatal = ds.get('any_errors_fatal', False)
self.accelerate = ds.get('accelerate', False)
self.max_fail_pct = int(ds.get('max_fail_percentage', 100))
load_vars = {}

View file

@ -138,7 +138,8 @@ class Runner(object):
diff=False, # whether to show diffs for template files that change
environment=None, # environment variables (as dict) to use inside the command
complex_args=None, # structured data in addition to module_args, must be a dict
error_on_undefined_vars=C.DEFAULT_UNDEFINED_VAR_BEHAVIOR # ex. False
error_on_undefined_vars=C.DEFAULT_UNDEFINED_VAR_BEHAVIOR, # ex. False
accelerate=False, # use fireball acceleration
):
# used to lock multiprocess inputs and outputs at various levels
@ -179,11 +180,16 @@ class Runner(object):
self.environment = environment
self.complex_args = complex_args
self.error_on_undefined_vars = error_on_undefined_vars
self.accelerate = accelerate
self.callbacks.runner = self
if self.accelerate:
# if we're using accelerated mode, force the local
# transport to fireball2
self.transport = "fireball2"
elif self.transport == 'smart':
# if the transport is 'smart' see if SSH can support ControlPersist if not use paramiko
# 'smart' is the default since 1.2.1/1.3
if self.transport == 'smart':
cmd = subprocess.Popen(['ssh','-o','ControlPersist'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
(out, err) = cmd.communicate()
if "Bad configuration option" in err:

View file

@ -19,7 +19,9 @@ import json
import os
import base64
import socket
import struct
from ansible.callbacks import vvv
from ansible.runner.connection_plugins.ssh import Connection as SSHConnection
from ansible import utils
from ansible import errors
from ansible import constants
@ -27,32 +29,68 @@ from ansible import constants
class Connection(object):
''' raw socket accelerated connection '''
def __init__(self, runner, host, port, *args, **kwargs):
def __init__(self, runner, host, port, user, password, private_key_file, *args, **kwargs):
self.ssh = SSHConnection(
runner=runner,
host=host,
port=port,
user=user,
password=password,
private_key_file=private_key_file
)
self.runner = runner
self.host = host
self.context = None
self.conn = None
self.key = utils.key_for_hostname(host)
self.fbport = constants.FIREBALL2_PORT
self.is_connected = False
# attempt to work around shared-memory funness
if getattr(self.runner, 'aes_keys', None):
utils.AES_KEYS = self.runner.aes_keys
self.host = host
self.context = None
self.conn = None
self.cipher = AES256Cipher()
def _execute_fb_module(self):
args = "password=%s" % base64.b64encode(self.key.__str__())
self.ssh.connect()
return self.runner._execute_module(self.ssh, "/root/.ansible/tmp", 'fireball2', args, inject={"password":self.key})
if port is None:
self.port = constants.FIREBALL2_PORT
else:
self.port = port
def connect(self):
def connect(self, allow_ssh=True):
''' activates the connection object '''
self.conn = socket.socket()
self.conn.connect((self.host,self.port))
if self.is_connected:
return self
try:
self.conn = socket.socket()
self.conn.connect((self.host,self.fbport))
except:
if allow_ssh:
print "Falling back to ssh to startup accelerated mode"
res = self._execute_fb_module()
return self.connect(allow_ssh=False)
else:
raise errors.AnsibleError("Failed to connect to %s:%s" % (self.host,self.fbport))
self.is_connected = True
return self
def send_data(self, data):
packed_len = struct.pack('Q',len(data))
return self.conn.sendall(packed_len + data)
def recv_data(self):
header_len = 8 # size of a packed unsigned long long
data = b""
while len(data) < header_len:
data += self.conn.recv(1024)
data_len = struct.unpack('Q',data[:header_len])[0]
data = data[header_len:]
while len(data) < data_len:
data += self.conn.recv(1024)
return data
def exec_command(self, cmd, tmp_path, sudo_user, sudoable=False, executable='/bin/sh'):
''' run a command on the remote host '''
@ -65,12 +103,12 @@ class Connection(object):
executable=executable,
)
data = utils.jsonify(data)
data = self.cipher.encrypt(data)
if self.conn.sendall(data):
data = utils.encrypt(self.key, data)
if self.send_data(data):
raise errors.AnisbleError("Failed to send command to %s:%s" % (self.host,self.port))
response = self.conn.recv(2048)
response = self.cipher.decrypt(response)
response = self.recv_data()
response = utils.decrypt(self.key, response)
response = utils.parse_json(response)
return (response.get('rc',None), '', response.get('stdout',''), response.get('stderr',''))
@ -83,18 +121,18 @@ class Connection(object):
if not os.path.exists(in_path):
raise errors.AnsibleFileNotFound("file or module does not exist: %s" % in_path)
data = base64.file(in_path).read()
data = file(in_path).read()
data = base64.b64encode(data)
data = dict(mode='put', data=data, out_path=out_path)
# TODO: support chunked file transfer
data = utils.jsonify(data)
data = self.cipher.encrypt(data)
if self.conn.sendall(data):
data = utils.encrypt(self.key, data)
if self.send_data(data):
raise errors.AnsibleError("failed to send the file to %s:%s" % (self.host,self.port))
response = self.conn.recv(2048)
response = self.cipher.decrypt(response)
response = self.recv_data()
response = utils.decrypt(self.key, data)
response = utils.parse_json(response)
# no meaningful response needed for this
@ -105,12 +143,12 @@ class Connection(object):
data = dict(mode='fetch', in_path=in_path)
data = utils.jsonify(data)
data = self.cipher.encrypt(data)
if self.conn.sendall(data):
data = utils.encrypt(self.key, data)
if self.send_data(data):
raise errors.AnsibleError("failed to initiate the file fetch with %s:%s" % (self.host,self.port))
response = self.socket.recv(2048)
response = self.cipher.decrypt(response)
response = self.recv_data()
response = utils.decrypt(self.key, data)
response = utils.parse_json(response)
response = response['data']
response = base64.b64decode(response)

View file

@ -31,7 +31,6 @@ import ansible.constants as C
import time
import StringIO
import stat
import string
import termios
import tty
import pipes
@ -41,11 +40,6 @@ import warnings
import traceback
import getpass
import hmac
from Crypto.Cipher import
from Crypto import Random
from Crypto.Random.random import StrongRandom
VERBOSITY=0
MAX_FILE_SIZE_FOR_DIFF=1*1024*1024
@ -57,10 +51,8 @@ except ImportError:
try:
from hashlib import md5 as _md5
from hashlib import sha1 as _sha1
except ImportError:
from md5 import md5 as _md5
from sha1 import sha1 as _sha1
PASSLIB_AVAILABLE = False
try:
@ -69,128 +61,51 @@ try:
except:
pass
KEYCZAR_AVAILABLE=False
try:
import keyczar.errors as key_errors
from keyczar.keys import AesKey
KEYCZAR_AVAILABLE=True
except ImportError:
pass
###############################################################
# Abstractions around PyCrypto
# Abstractions around keyczar
###############################################################
class AES256Cipher(object):
"""
Class abstraction of an AES 256 cipher. This class
also keeps track of the time since the key was last
generated, so you know when to rekey. Rekeying would
be done as follows:
def key_for_hostname(hostname):
# fireball mode is an implementation of ansible firing up zeromq via SSH
# to use no persistent daemons or key management
k = AES256Cipher.gen_key()
<exchange new key with client securely>
AES26Cipher.set_key(k)
if not KEYCZAR_AVAILABLE:
raise errors.AnsibleError("python-keyczar must be installed to use fireball mode")
From this point on the new key would be used until
the lifetime is exceeded.
"""
def __init__(self, lifetime=60*30, mode=AES.MODE_CFB):
self.lifetime = lifetime
self.mode = mode
self.set_key(self.gen_key())
key_path = os.path.expanduser("~/.fireball.keys")
if not os.path.exists(key_path):
os.makedirs(key_path)
key_path = os.path.expanduser("~/.fireball.keys/%s" % hostname)
def gen_key(self):
"""
Generates a 256-bit (32 byte) key to be used for the
AES block encryption.
"""
return b"".join(StrongRandom().sample(string.letters+string.digits+string.punctuation,32))
def set_key(self,key):
"""
Sets the internal key to the one provided and resets the
internal time to now. This key should ONLY be set to one
generated by gen_key()
"""
self.init_time = time.time()
self.key = key
def should_rekey(self):
"""
Returns true if the lifetime of the current key has
exceeded the set lifetime.
"""
if (time.time() - self.init_time) > self.lifetime:
return True
# use new AES keys every 2 hours, which means fireball must not allow running for longer either
if not os.path.exists(key_path) or (time.time() - os.path.getmtime(key_path) > 60*60*2):
key = AesKey.Generate()
fh = open(key_path, "w")
fh.write(str(key))
fh.close()
return key
else:
return False
fh = open(key_path)
key = AesKey.Read(fh.read())
fh.close()
return key
def _pad(self, msg):
"""
Adds padding to the message so that it is a full
AES block size. Used during encryption of the message.
"""
pad = AES.block_size - len(msg) % AES.block_size
return msg + pad * chr(pad)
def encrypt(key, msg):
return key.Encrypt(msg)
def _unpad(self, msg):
"""
Strips out the padding that _pad added. Used during
the decryption of the message.
"""
pad = ord(msg[-1])
return msg[:-pad]
def gen_sig(self, msg):
"""
Generates an HMAC-SHA1 signature for the message
"""
return hmac.new(self.key, msg, _sha1).digest()
def validate_sig(self, msg, sig):
"""
Verifies the generated signature of the message matches
the signature provided.
"""
new_sig = self.gen_sig(msg)
return (new_sig == sig)
def encrypt(self, msg):
"""
Encrypt the message using AES. The signature
is appended to the end of the message and is
used to verify the integrity of the IV and data.
Returns a base64-encoded version of the following:
rval[0:16] = initialization vector
rval[16:-20] = cipher text
rval[-20:] = signature
"""
msg = self._pad(msg)
iv = Random.new().read(AES.block_size)
cipher = AES.new(self.key, self.mode, iv)
data = iv + cipher.encrypt(msg)
sig = self.gen_sig(data)
return (data + sig).encode('base64')
def decrypt(self, msg):
"""
Decrypt the message using AES. The signature is
used to verify the IV and data before decoding to
ensure the integrity of the message. This is an
HMAC-SHA1 hash, so it is always 20 characters
The incoming message format (after base64 decoding)
is as follows:
msg[0:16] = initialization vector
msg[16:-20] = cipher text
msg[-20:] = signature (HMAC-SHA1)
Returns the plain-text of the cipher.
"""
msg = msg.decode('base64')
data = msg[0:-20] # iv + cipher text
msig = msg[-20:] # hmac-sha1 hash
if not self.validate_sig(data,msig):
raise Exception("Failed to validate the message signature")
iv = msg[:AES.block_size]
cipher = AES.new(self.key, self.mode, iv)
return self._unpad(cipher.decrypt(msg)[AES.block_size:])
def decrypt(key, msg):
try:
return key.Decrypt(msg)
except key_errors.InvalidSignatureError:
raise errors.AnsibleError("decryption failed")
###############################################################
# UTILITY FUNCTIONS FOR COMMAND LINE TOOLS

284
library/utilities/fireball2 Normal file
View file

@ -0,0 +1,284 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# (c) 2013, James Cammarata <jcammarata@ansibleworks.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
DOCUMENTATION = '''
---
module: fireball2
short_description: Enable fireball2 mode on remote node
description:
- This modules launches an ephemeral I(fireball2) daemon on the remote node which
Ansible can use to communicate with nodes at high speed.
- The daemon listens on a configurable port for a configurable amount of time.
- Starting a new fireball2 as a given user terminates any existing user fireballs2.
- Fireball mode is AES encrypted
version_added: "1.3"
options:
port:
description:
- TCP port for the socket connection
required: false
default: 5099
aliases: []
minutes:
description:
- The I(fireball2) listener daemon is started on nodes and will stay around for
this number of minutes before turning itself off.
required: false
default: 30
notes:
- See the advanced playbooks chapter for more about using fireball2 mode.
requirements: [ "pycrypto" ]
author: James Cammarata
'''
EXAMPLES = '''
# To use fireball2 mode, simple add "accelerated: true" to your play. The initial
# key exchange and starting up of the daemon will occur over SSH, but all commands and
# subsequent actions will be conducted over the raw socket connection using AES encryption
- hosts: devservers
accelerated: true
tasks:
- command: /usr/bin/anything
'''
import os
import sys
import shutil
import socket
import struct
import time
import base64
import syslog
import signal
import time
import signal
import traceback
import SocketServer
syslog.openlog('ansible-%s' % os.path.basename(__file__))
PIDFILE = os.path.expanduser("~/.fireball2.pid")
def log(msg):
syslog.syslog(syslog.LOG_NOTICE, msg)
if os.path.exists(PIDFILE):
try:
data = int(open(PIDFILE).read())
try:
os.kill(data, signal.SIGKILL)
except OSError:
pass
except ValueError:
pass
os.unlink(PIDFILE)
HAS_KEYCZAR = False
try:
from keyczar.keys import AesKey
HAS_KEYCZAR = True
except ImportError:
pass
# NOTE: this shares a fair amount of code in common with async_wrapper, if async_wrapper were a new module we could move
# this into utils.module_common and probably should anyway
def daemonize_self(module, password, port, minutes):
# daemonizing code: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66012
try:
pid = os.fork()
if pid > 0:
log("exiting pid %s" % pid)
# exit first parent
module.exit_json(msg="daemonized fireball2 on port %s for %s minutes" % (port, minutes))
except OSError, e:
log("fork #1 failed: %d (%s)" % (e.errno, e.strerror))
sys.exit(1)
# decouple from parent environment
os.chdir("/")
os.setsid()
os.umask(022)
# do second fork
try:
pid = os.fork()
if pid > 0:
log("daemon pid %s, writing %s" % (pid, PIDFILE))
pid_file = open(PIDFILE, "w")
pid_file.write("%s" % pid)
pid_file.close()
log("pidfile written")
sys.exit(0)
except OSError, e:
log("fork #2 failed: %d (%s)" % (e.errno, e.strerror))
sys.exit(1)
dev_null = file('/dev/null','rw')
os.dup2(dev_null.fileno(), sys.stdin.fileno())
os.dup2(dev_null.fileno(), sys.stdout.fileno())
os.dup2(dev_null.fileno(), sys.stderr.fileno())
log("daemonizing successful")
#class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
class ThreadedTCPServer(SocketServer.ThreadingTCPServer):
def __init__(self, server_address, RequestHandlerClass, module, password):
self.module = module
self.key = AesKey.Read(password)
self.allow_reuse_address = True
self.timeout = None
SocketServer.ThreadingTCPServer.__init__(self, server_address, RequestHandlerClass)
class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
def send_data(self, data):
packed_len = struct.pack('Q', len(data))
return self.request.sendall(packed_len + data)
def recv_data(self):
header_len = 8 # size of a packed unsigned long long
data = b""
while len(data) < header_len:
d = self.request.recv(1024)
if not d:
return None
data += d
data_len = struct.unpack('Q',data[:header_len])[0]
data = data[header_len:]
while len(data) < data_len:
data += self.request.recv(1024)
return data
def handle(self):
while True:
data = self.recv_data()
if not data:
break
try:
data = self.server.key.Decrypt(data)
except:
log("bad decrypt, skipping...")
data2 = json.dumps(dict(rc=1))
data2 = self.server.key.Encrypt(data2)
send_data(client, data2)
return
data = json.loads(data)
mode = data['mode']
response = {}
if mode == 'command':
response = self.command(data)
elif mode == 'put':
response = self.put(data)
elif mode == 'fetch':
response = self.fetch(data)
data2 = json.dumps(response)
data2 = self.server.key.Encrypt(data2)
self.send_data(data2)
def command(self, data):
if 'cmd' not in data:
return dict(failed=True, msg='internal error: cmd is required')
if 'tmp_path' not in data:
return dict(failed=True, msg='internal error: tmp_path is required')
if 'executable' not in data:
return dict(failed=True, msg='internal error: executable is required')
log("executing: %s" % data['cmd'])
rc, stdout, stderr = self.server.module.run_command(data['cmd'], executable=data['executable'], close_fds=True)
if stdout is None:
stdout = ''
if stderr is None:
stderr = ''
log("got stdout: %s" % stdout)
return dict(rc=rc, stdout=stdout, stderr=stderr)
def fetch(self, data):
if 'in_path' not in data:
return dict(failed=True, msg='internal error: in_path is required')
# FIXME: should probably support chunked file transfer for binary files
# at some point. For now, just base64 encodes the file
# so don't use it to move ISOs, use rsync.
fh = open(data['in_path'])
data = base64.b64encode(fh.read())
return dict(data=data)
def put(self, data):
if 'data' not in data:
return dict(failed=True, msg='internal error: data is required')
if 'out_path' not in data:
return dict(failed=True, msg='internal error: out_path is required')
# FIXME: should probably support chunked file transfer for binary files
# at some point. For now, just base64 encodes the file
# so don't use it to move ISOs, use rsync.
fh = open(data['out_path'], 'w')
fh.write(base64.b64decode(data['data']))
fh.close()
return dict()
def daemonize(module, password, port, minutes):
try:
daemonize_self(module, password, port, minutes)
def catcher(signum, _):
module.exit_json(msg='timer expired')
signal.signal(signal.SIGALRM, catcher)
signal.setitimer(signal.ITIMER_REAL, 60 * minutes)
server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password)
server.allow_reuse_address = True
server.serve_forever(poll_interval=1.0)
except Exception, e:
tb = traceback.format_exc()
log("exception caught, exiting fireball mode: %s\n%s" % (e, tb))
sys.exit(0)
def main():
module = AnsibleModule(
argument_spec = dict(
port=dict(required=False, default=5099),
password=dict(required=True),
minutes=dict(required=False, default=30),
)
)
password = base64.b64decode(module.params['password'])
port = module.params['port']
minutes = int(module.params['minutes'])
if not HAS_KEYCZAR:
module.fail_json(msg="keyczar is not installed")
daemonize(module, password, port, minutes)
# this is magic, see lib/ansible/module_common.py
#<<INCLUDE_ANSIBLE_MODULE_COMMON>>
main()