Initial support for sudoable commands over fireball2

Caveats:
* requiretty must be disabled in the sudoers config
* asking for a password doesn't work yet, so any sudoers users must
  be configured with NOPASSWD
* if not starting the daemon as root, the user running the daemon
  must have sudoers entries to allow them to run the command as the
  target sudo_user
This commit is contained in:
James Cammarata 2013-08-27 13:12:35 -05:00
parent 4b552457e7
commit b45342923c
2 changed files with 83 additions and 30 deletions

View file

@ -61,15 +61,25 @@ class Connection(object):
def connect(self, allow_ssh=True): def connect(self, allow_ssh=True):
''' activates the connection object ''' ''' activates the connection object '''
if self.is_connected:
return self
try: try:
self.conn = socket.socket() if not self.is_connected:
self.conn.connect((self.host,self.fbport)) # TODO: make the timeout and retries configurable?
tries = 10
self.conn = socket.socket()
self.conn.settimeout(30.0)
while tries > 0:
try:
self.conn.connect((self.host,self.fbport))
break
except:
time.sleep(0.1)
tries -= 1
if tries == 0:
vvv("Could not connect via the fireball2 connection, exceeded # of tries")
raise errors.AnsibleError("Failed to connect")
except: except:
if allow_ssh: if allow_ssh:
print "Falling back to ssh to startup accelerated mode" vvv("Falling back to ssh to startup accelerated mode")
res = self._execute_fb_module() res = self._execute_fb_module()
return self.connect(allow_ssh=False) return self.connect(allow_ssh=False)
else: else:
@ -84,23 +94,29 @@ class Connection(object):
def recv_data(self): def recv_data(self):
header_len = 8 # size of a packed unsigned long long header_len = 8 # size of a packed unsigned long long
data = b"" data = b""
while len(data) < header_len: try:
d = self.conn.recv(1024) while len(data) < header_len:
if not d: d = self.conn.recv(1024)
return None if not d:
data += d return None
data_len = struct.unpack('Q',data[:header_len])[0] data += d
data = data[header_len:] data_len = struct.unpack('Q',data[:header_len])[0]
while len(data) < data_len: data = data[header_len:]
d = self.conn.recv(1024) while len(data) < data_len:
if not d: d = self.conn.recv(1024)
return None if not d:
data += d return None
return data data += d
return data
except socket.timeout:
raise errors.AnsibleError("timed out while waiting to receive data")
def exec_command(self, cmd, tmp_path, sudo_user, sudoable=False, executable='/bin/sh'): def exec_command(self, cmd, tmp_path, sudo_user, sudoable=False, executable='/bin/sh'):
''' run a command on the remote host ''' ''' run a command on the remote host '''
if self.runner.sudo or sudoable and sudo_user:
cmd, prompt = utils.make_sudo_cmd(sudo_user, executable, cmd)
vvv("EXEC COMMAND %s" % cmd) vvv("EXEC COMMAND %s" % cmd)
data = dict( data = dict(
@ -112,12 +128,15 @@ class Connection(object):
data = utils.jsonify(data) data = utils.jsonify(data)
data = utils.encrypt(self.key, data) data = utils.encrypt(self.key, data)
if self.send_data(data): if self.send_data(data):
raise errors.AnisbleError("Failed to send command to %s:%s" % (self.host,self.port)) raise errors.AnisbleError("Failed to send command to %s" % self.host)
response = self.recv_data() response = self.recv_data()
if not response:
raise errors.AnsibleError("Failed to get a response from %s" % self.host)
response = utils.decrypt(self.key, response) response = utils.decrypt(self.key, response)
response = utils.parse_json(response) response = utils.parse_json(response)
vvv("COMMAND DONE: rc=%s" % str(response.get('rc',"<unknown>")))
return (response.get('rc',None), '', response.get('stdout',''), response.get('stderr','')) return (response.get('rc',None), '', response.get('stdout',''), response.get('stderr',''))
def put_file(self, in_path, out_path): def put_file(self, in_path, out_path):
@ -132,17 +151,23 @@ class Connection(object):
data = base64.b64encode(data) data = base64.b64encode(data)
data = dict(mode='put', data=data, out_path=out_path) data = dict(mode='put', data=data, out_path=out_path)
if self.runner.sudo:
data['user'] = self.runner.sudo_user
# TODO: support chunked file transfer # TODO: support chunked file transfer
data = utils.jsonify(data) data = utils.jsonify(data)
data = utils.encrypt(self.key, data) data = utils.encrypt(self.key, data)
if self.send_data(data): if self.send_data(data):
raise errors.AnsibleError("failed to send the file to %s:%s" % (self.host,self.port)) raise errors.AnsibleError("failed to send the file to %s" % self.host)
response = self.recv_data() response = self.recv_data()
response = utils.decrypt(self.key, data) if not response:
raise errors.AnsibleError("Failed to get a response from %s" % self.host)
response = utils.decrypt(self.key, response)
response = utils.parse_json(response) response = utils.parse_json(response)
# no meaningful response needed for this if response.get('failed',False):
raise errors.AnsibleError("failed to put the file in the requested location")
def fetch_file(self, in_path, out_path): def fetch_file(self, in_path, out_path):
''' save a remote file to the specified path ''' ''' save a remote file to the specified path '''
@ -152,10 +177,12 @@ class Connection(object):
data = utils.jsonify(data) data = utils.jsonify(data)
data = utils.encrypt(self.key, data) data = utils.encrypt(self.key, data)
if self.send_data(data): if self.send_data(data):
raise errors.AnsibleError("failed to initiate the file fetch with %s:%s" % (self.host,self.port)) raise errors.AnsibleError("failed to initiate the file fetch with %s" % self.host)
response = self.recv_data() response = self.recv_data()
response = utils.decrypt(self.key, data) if not response:
raise errors.AnsibleError("Failed to get a response from %s" % self.host)
response = utils.decrypt(self.key, response)
response = utils.parse_json(response) response = utils.parse_json(response)
response = response['data'] response = response['data']
response = base64.b64decode(response) response = base64.b64decode(response)

View file

@ -60,12 +60,15 @@ EXAMPLES = '''
''' '''
import os import os
import os.path
import tempfile
import sys import sys
import shutil import shutil
import socket import socket
import struct import struct
import time import time
import base64 import base64
import getpass
import syslog import syslog
import signal import signal
import time import time
@ -138,7 +141,6 @@ def daemonize_self(module, password, port, minutes):
os.dup2(dev_null.fileno(), sys.stderr.fileno()) os.dup2(dev_null.fileno(), sys.stderr.fileno())
log("daemonizing successful") log("daemonizing successful")
#class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
class ThreadedTCPServer(SocketServer.ThreadingTCPServer): class ThreadedTCPServer(SocketServer.ThreadingTCPServer):
def __init__(self, server_address, RequestHandlerClass, module, password): def __init__(self, server_address, RequestHandlerClass, module, password):
self.module = module self.module = module
@ -171,11 +173,14 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
def handle(self): def handle(self):
while True: while True:
#log("waiting for data")
data = self.recv_data() data = self.recv_data()
if not data: if not data:
break break
try: try:
#log("got data, decrypting")
data = self.server.key.Decrypt(data) data = self.server.key.Decrypt(data)
#log("decryption done")
except: except:
log("bad decrypt, skipping...") log("bad decrypt, skipping...")
data2 = json.dumps(dict(rc=1)) data2 = json.dumps(dict(rc=1))
@ -183,6 +188,7 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
send_data(client, data2) send_data(client, data2)
return return
#log("loading json from the data")
data = json.loads(data) data = json.loads(data)
mode = data['mode'] mode = data['mode']
@ -212,7 +218,8 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
stdout = '' stdout = ''
if stderr is None: if stderr is None:
stderr = '' stderr = ''
log("got stdout: %s" % stdout) #log("got stdout: %s" % stdout)
#log("got stderr: %s" % stderr)
return dict(rc=rc, stdout=stdout, stderr=stderr) return dict(rc=rc, stdout=stdout, stderr=stderr)
@ -234,14 +241,32 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler):
if 'out_path' not in data: if 'out_path' not in data:
return dict(failed=True, msg='internal error: out_path is required') return dict(failed=True, msg='internal error: out_path is required')
final_path = None
if 'user' in data and data.get('user') != getpass.getuser():
log("the target user doesn't match this user, we'll move the file into place via sudo")
(fd,out_path) = tempfile.mkstemp(prefix='ansible.', dir=os.path.expanduser('~/.ansible/tmp/'))
out_fd = os.fdopen(fd, 'w', 0)
final_path = data['out_path']
else:
out_path = data['out_path']
out_fd = open(out_path, 'w')
# FIXME: should probably support chunked file transfer for binary files # FIXME: should probably support chunked file transfer for binary files
# at some point. For now, just base64 encodes the file # at some point. For now, just base64 encodes the file
# so don't use it to move ISOs, use rsync. # so don't use it to move ISOs, use rsync.
fh = open(data['out_path'], 'w') try:
fh.write(base64.b64decode(data['data'])) out_fd.write(base64.b64decode(data['data']))
fh.close() out_fd.close()
except:
return dict(failed=True, stdout="Could not write the file")
if final_path:
log("moving %s to %s" % (out_path, final_path))
args = ['sudo','mv',out_path,final_path]
rc, stdout, stderr = self.server.module.run_command(args, close_fds=True)
if rc != 0:
return dict(failed=True, stdout="failed to copy the file into position with sudo")
return dict() return dict()
def daemonize(module, password, port, minutes): def daemonize(module, password, port, minutes):
@ -257,6 +282,7 @@ def daemonize(module, password, port, minutes):
server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password) server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password)
server.allow_reuse_address = True server.allow_reuse_address = True
log("serving!")
server.serve_forever(poll_interval=1.0) server.serve_forever(poll_interval=1.0)
except Exception, e: except Exception, e:
tb = traceback.format_exc() tb = traceback.format_exc()