diff --git a/utilities/accelerate b/utilities/accelerate index a61e54e374d..5a8c96c64a9 100644 --- a/utilities/accelerate +++ b/utilities/accelerate @@ -53,6 +53,14 @@ options: if this parameter is set to true. required: false default: false + multi_key: + description: + - When enabled, the daemon will open a local socket file which can be used by future daemon executions to + upload a new key to the already running daemon, so that multiple users can connect using different keys. + This access still requires an ssh connection as the uid for which the daemon is currently running. + required: false + default: no + version_added: "1.6" notes: - See the advanced playbooks chapter for more about using accelerated mode. requirements: [ "python-keyczar" ] @@ -71,6 +79,7 @@ EXAMPLES = ''' ''' import base64 +import errno import getpass import json import os @@ -88,10 +97,13 @@ import traceback import SocketServer from datetime import datetime -from threading import Thread +from threading import Thread, Lock + +# import module snippets +# we must import this here at the top so we can use get_module_path() +from ansible.module_utils.basic import * syslog.openlog('ansible-%s' % os.path.basename(__file__)) -PIDFILE = os.path.expanduser("~/.accelerate.pid") # the chunk size to read and send, assuming mtu 1500 and # leaving room for base64 (+33%) encoding and header (100 bytes) @@ -107,6 +119,9 @@ def log(msg, cap=0): if DEBUG_LEVEL >= cap: syslog.syslog(syslog.LOG_NOTICE|syslog.LOG_DAEMON, msg) +def v(msg): + log(msg, cap=1) + def vv(msg): log(msg, cap=2) @@ -116,16 +131,6 @@ def vvv(msg): def vvvv(msg): log(msg, cap=4) -if os.path.exists(PIDFILE): - try: - data = int(open(PIDFILE).read()) - try: - os.kill(data, signal.SIGKILL) - except OSError: - pass - except ValueError: - pass - os.unlink(PIDFILE) HAS_KEYCZAR = False try: @@ -134,10 +139,26 @@ try: except ImportError: pass +SOCKET_FILE = os.path.join(get_module_path(), '.ansible-accelerate', ".local.socket") + +def get_pid_location(module): + """ + Try to find a pid directory in the common locations, falling + back to the user's home directory if no others exist + """ + for dir in ['/var/run', '/var/lib/run', '/run', os.path.expanduser("~/")]: + try: + if os.path.isdir(dir) and os.access(dir, os.R_OK|os.W_OK): + return os.path.join(dir, '.accelerate.pid') + except: + pass + module.fail_json(msg="couldn't find any valid directory to use for the accelerate pid file") + + # NOTE: this shares a fair amount of code in common with async_wrapper, if async_wrapper were a new module we could move # this into utils.module_common and probably should anyway -def daemonize_self(module, password, port, minutes): +def daemonize_self(module, password, port, minutes, pid_file): # daemonizing code: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66012 try: pid = os.fork() @@ -158,11 +179,11 @@ def daemonize_self(module, password, port, minutes): try: pid = os.fork() if pid > 0: - log("daemon pid %s, writing %s" % (pid, PIDFILE)) - pid_file = open(PIDFILE, "w") + log("daemon pid %s, writing %s" % (pid, pid_file)) + pid_file = open(pid_file, "w") pid_file.write("%s" % pid) pid_file.close() - vvv("pidfile written") + vvv("pid file written") sys.exit(0) except OSError, e: log("fork #2 failed: %d (%s)" % (e.errno, e.strerror)) @@ -174,8 +195,85 @@ def daemonize_self(module, password, port, minutes): os.dup2(dev_null.fileno(), sys.stderr.fileno()) log("daemonizing successful") -class ThreadWithReturnValue(Thread): +class LocalSocketThread(Thread): + server = None + terminated = False + def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, Verbose=None): + self.server = kwargs.get('server') + Thread.__init__(self, group, target, name, args, kwargs, Verbose) + + def run(self): + try: + if os.path.exists(SOCKET_FILE): + os.remove(SOCKET_FILE) + else: + dir = os.path.dirname(SOCKET_FILE) + if os.path.exists(dir): + if not os.path.isdir(dir): + log("The socket file path (%s) exists, but is not a directory. No local connections will be available" % dir) + return + else: + # make sure the directory is accessible only to this + # user, as socket files derive their permissions from + # the directory that contains them + os.chmod(dir, 0700) + elif not os.path.exists(dir): + os.makedirs(dir, 0700) + except OSError: + pass + self.s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.s.bind(SOCKET_FILE) + self.s.listen(5) + while not self.terminated: + try: + conn, addr = self.s.accept() + vv("received local connection") + data = "" + while "\n" not in data: + data += conn.recv(2048) + try: + new_key = AesKey.Read(data.strip()) + found = False + for key in self.server.key_list: + try: + new_key.Decrypt(key.Encrypt("foo")) + found = True + break + except: + pass + if not found: + vv("adding new key to the key list") + self.server.key_list.append(new_key) + conn.sendall("OK\n") + else: + vv("key already exists in the key list, ignoring") + conn.sendall("EXISTS\n") + + # update the last event time so the server doesn't + # shutdown sooner than expected for new cliets + try: + self.server.last_event_lock.acquire() + self.server.last_event = datetime.now() + finally: + self.server.last_event_lock.release() + except Exception, e: + vv("key loaded locally was invalid, ignoring (%s)" % e) + conn.sendall("BADKEY\n") + finally: + try: + conn.close() + except: + pass + except: + pass + + def terminate(self): + self.terminated = True + self.s.shutdown(socket.SHUT_RDWR) + self.s.close() + +class ThreadWithReturnValue(Thread): def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, Verbose=None): Thread.__init__(self, group, target, name, args, kwargs, Verbose) self._return = None @@ -190,24 +288,41 @@ class ThreadWithReturnValue(Thread): return self._return class ThreadedTCPServer(SocketServer.ThreadingTCPServer): - def __init__(self, server_address, RequestHandlerClass, module, password, timeout): + key_list = [] + last_event = datetime.now() + last_event_lock = Lock() + def __init__(self, server_address, RequestHandlerClass, module, password, timeout, use_ipv6=False): self.module = module - self.key = AesKey.Read(password) + self.key_list.append(AesKey.Read(password)) self.allow_reuse_address = True self.timeout = timeout + + if use_ipv6: + self.address_family = socket.AF_INET6 + + if self.module.params.get('multi_key', False): + vv("starting thread to handle local connections for multiple keys") + self.local_thread = LocalSocketThread(kwargs=dict(server=self)) + self.local_thread.start() + SocketServer.ThreadingTCPServer.__init__(self, server_address, RequestHandlerClass) -class ThreadedTCPV6Server(SocketServer.ThreadingTCPServer): - def __init__(self, server_address, RequestHandlerClass, module, password, timeout): - self.module = module - self.address_family = socket.AF_INET6 - self.key = AesKey.Read(password) - self.allow_reuse_address = True - self.timeout = timeout - SocketServer.ThreadingTCPServer.__init__(self, server_address, RequestHandlerClass) + def shutdown(self): + self.local_thread.terminate() + self.running = False + SocketServer.ThreadingTCPServer.shutdown(self) class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): + # the key to use for this connection + active_key = None + def send_data(self, data): + try: + self.server.last_event_lock.acquire() + self.server.last_event = datetime.now() + finally: + self.server.last_event_lock.release() + packed_len = struct.pack('!Q', len(data)) return self.request.sendall(packed_len + data) @@ -216,23 +331,40 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): data = "" vvvv("in recv_data(), waiting for the header") while len(data) < header_len: - d = self.request.recv(header_len - len(data)) - if not d: - vvv("received nothing, bailing out") + try: + d = self.request.recv(header_len - len(data)) + if not d: + vvv("received nothing, bailing out") + return None + data += d + except: + # probably got a connection reset + vvvv("exception received while waiting for recv(), returning None") return None - data += d vvvv("in recv_data(), got the header, unpacking") data_len = struct.unpack('!Q',data[:header_len])[0] data = data[header_len:] vvvv("data received so far (expecting %d): %d" % (data_len,len(data))) while len(data) < data_len: - d = self.request.recv(data_len - len(data)) - if not d: - vvv("received nothing, bailing out") + try: + d = self.request.recv(data_len - len(data)) + if not d: + vvv("received nothing, bailing out") + return None + data += d + vvvv("data received so far (expecting %d): %d" % (data_len,len(data))) + except: + # probably got a connection reset + vvvv("exception received while waiting for recv(), returning None") return None - data += d - vvvv("data received so far (expecting %d): %d" % (data_len,len(data))) vvvv("received all of the data, returning") + + try: + self.server.last_event_lock.acquire() + self.server.last_event = datetime.now() + finally: + self.server.last_event_lock.release() + return data def handle(self): @@ -243,18 +375,26 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): if not data: vvvv("received nothing back from recv_data(), breaking out") break - try: - vvvv("got data, decrypting") - data = self.server.key.Decrypt(data) - vvvv("decryption done") - except: - vv("bad decrypt, skipping...") - data2 = json.dumps(dict(rc=1)) - data2 = self.server.key.Encrypt(data2) - self.send_data(data2) - return + vvvv("got data, decrypting") + if not self.active_key: + for key in self.server.key_list: + try: + data = key.Decrypt(data) + self.active_key = key + break + except: + pass + else: + vv("bad decrypt, exiting the connection handler") + return + else: + try: + data = self.active_key.Decrypt(data) + except: + vv("bad decrypt, exiting the connection handler") + return - vvvv("loading json from the data") + vvvv("decryption done, loading json from the data") data = json.loads(data) mode = data['mode'] @@ -270,7 +410,7 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): last_pong = datetime.now() vvvv("command still running, sending keepalive packet") data2 = json.dumps(dict(pong=True)) - data2 = self.server.key.Encrypt(data2) + data2 = self.active_key.Encrypt(data2) self.send_data(data2) time.sleep(0.1) response = twrv._return @@ -286,8 +426,9 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): response = self.validate_user(data) vvvv("response result is %s" % str(response)) - data2 = json.dumps(response) - data2 = self.server.key.Encrypt(data2) + json_response = json.dumps(response) + vvvv("dumped json is %s" % json_response) + data2 = self.active_key.Encrypt(json_response) vvvv("sending the response back to the controller") self.send_data(data2) vvvv("done sending the response") @@ -299,9 +440,10 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): tb = traceback.format_exc() log("encountered an unhandled exception in the handle() function") log("error was:\n%s" % tb) - data2 = json.dumps(dict(rc=1, failed=True, msg="unhandled error in the handle() function")) - data2 = self.server.key.Encrypt(data2) - self.send_data(data2) + if self.active_key: + data2 = json.dumps(dict(rc=1, failed=True, msg="unhandled error in the handle() function")) + data2 = self.active_key.Encrypt(data2) + self.send_data(data2) def validate_user(self, data): if 'username' not in data: @@ -362,7 +504,7 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): last = True data = dict(data=base64.b64encode(data), last=last) data = json.dumps(data) - data = self.server.key.Encrypt(data) + data = self.active_key.Encrypt(data) if self.send_data(data): return dict(failed=True, stderr="failed to send data") @@ -371,7 +513,7 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): if not response: log("failed to get a response, aborting") return dict(failed=True, stderr="Failed to get a response from %s" % self.host) - response = self.server.key.Decrypt(response) + response = self.active_key.Decrypt(response) response = json.loads(response) if response.get('failed',False): @@ -394,7 +536,7 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): final_path = None if 'user' in data and data.get('user') != getpass.getuser(): - vv("the target user doesn't match this user, we'll move the file into place via sudo") + vvv("the target user doesn't match this user, we'll move the file into place via sudo") tmp_path = os.path.expanduser('~/.ansible/tmp/') if not os.path.exists(tmp_path): try: @@ -415,14 +557,14 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): bytes += len(out) out_fd.write(out) response = json.dumps(dict()) - response = self.server.key.Encrypt(response) + response = self.active_key.Encrypt(response) self.send_data(response) if data['last']: break data = self.recv_data() if not data: raise "" - data = self.server.key.Decrypt(data) + data = self.active_key.Decrypt(data) data = json.loads(data) except: out_fd.close() @@ -438,27 +580,45 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): self.server.module.atomic_move(out_path, final_path) return dict() -def daemonize(module, password, port, timeout, minutes, ipv6): +def daemonize(module, password, port, timeout, minutes, use_ipv6, pid_file): try: - daemonize_self(module, password, port, minutes) + daemonize_self(module, password, port, minutes, pid_file) - def catcher(signum, _): - module.exit_json(msg='timer expired') + def timer_handler(signum, _): + try: + server.last_event_lock.acquire() + td = datetime.now() - server.last_event + # older python timedelta objects don't have total_seconds(), + # so we use the formula from the docs to calculate it + total_seconds = (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6 + if total_seconds >= minutes * 60: + log("server has been idle longer than the timeout, shutting down") + server.running = False + server.shutdown() + else: + # reschedule the check + vvvv("daemon idle for %d seconds (timeout=%d)" % (total_seconds,minutes*60)) + signal.alarm(30) + except: + pass + finally: + server.last_event_lock.release() - signal.signal(signal.SIGALRM, catcher) - signal.setitimer(signal.ITIMER_REAL, 60 * minutes) + signal.signal(signal.SIGALRM, timer_handler) + signal.alarm(30) tries = 5 while tries > 0: try: - if ipv6: - server = ThreadedTCPV6Server(("::", port), ThreadedTCPRequestHandler, module, password, timeout) + if use_ipv6: + address = ("::", port) else: - server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password, timeout) + address = ("0.0.0.0", port) + server = ThreadedTCPServer(address, ThreadedTCPRequestHandler, module, password, timeout, use_ipv6=use_ipv6) server.allow_reuse_address = True break - except: - vv("Failed to create the TCP server (tries left = %d)" % tries) + except Exception, e: + vv("Failed to create the TCP server (tries left = %d) (error: %s) " % (tries,e)) tries -= 1 time.sleep(0.2) @@ -466,8 +626,20 @@ def daemonize(module, password, port, timeout, minutes, ipv6): vv("Maximum number of attempts to create the TCP server reached, bailing out") raise Exception("max # of attempts to serve reached") - vv("serving!") - server.serve_forever(poll_interval=0.1) + # run the server in a separate thread to make signal handling work + server_thread = Thread(target=server.serve_forever, kwargs=dict(poll_interval=0.1)) + server_thread.start() + server.running = True + + v("serving!") + while server.running: + time.sleep(1) + + # wait for the thread to exit fully + server_thread.join() + + v("server thread terminated, exiting!") + sys.exit(0) except Exception, e: tb = traceback.format_exc() log("exception caught, exiting accelerated mode: %s\n%s" % (e, tb)) @@ -479,6 +651,7 @@ def main(): argument_spec = dict( port=dict(required=False, default=5099), ipv6=dict(required=False, default=False, type='bool'), + multi_key=dict(required=False, default=False, type='bool'), timeout=dict(required=False, default=300), password=dict(required=True), minutes=dict(required=False, default=30), @@ -493,14 +666,62 @@ def main(): minutes = int(module.params['minutes']) debug = int(module.params['debug']) ipv6 = module.params['ipv6'] + multi_key = module.params['multi_key'] if not HAS_KEYCZAR: module.fail_json(msg="keyczar is not installed (on the remote side)") DEBUG_LEVEL=debug + pid_file = get_pid_location(module) - daemonize(module, password, port, timeout, minutes, ipv6) + daemon_pid = None + daemon_running = False + if os.path.exists(pid_file): + try: + daemon_pid = int(open(pid_file).read()) + try: + # sending signal 0 doesn't do anything to the + # process, other than tell the calling program + # whether other signals can be sent + os.kill(daemon_pid, 0) + except OSError, e: + if e.errno == errno.EPERM: + # no permissions means the pid is probably + # running, but as a different user, so fail + module.fail_json(msg="the accelerate daemon appears to be running as a different user that this user cannot access (pid=%d)" % daemon_pid) + else: + daemon_running = True + except ValueError: + # invalid pid file, unlink it - otherwise we don't care + try: + os.unlink(pid_file) + except: + pass + + if daemon_running and multi_key: + # try to connect to the file socket for the daemon if it exists + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + s.connect(SOCKET_FILE) + s.sendall(password + '\n') + data = "" + while '\n' not in data: + data += s.recv(2048) + res = data.strip() + except: + module.fail_json(msg="failed to connect to the local socket file") + finally: + try: + s.close() + except: + pass + + if res in ("OK", "EXISTS"): + module.exit_json(msg="transferred new key to the existing daemon") + else: + module.fail_json(msg="could not transfer new key: %s" % data.strip()) + else: + # try to start up the daemon + daemonize(module, password, port, timeout, minutes, ipv6, pid_file) -# import module snippets -from ansible.module_utils.basic import * main()