diff --git a/lib/ansible/constants.py b/lib/ansible/constants.py index 7d0b270beb7..a9b4c29af06 100644 --- a/lib/ansible/constants.py +++ b/lib/ansible/constants.py @@ -165,6 +165,7 @@ ACCELERATE_CONNECT_TIMEOUT = get_config(p, 'accelerate', 'accelerate_connect ACCELERATE_KEYS_DIR = get_config(p, 'accelerate', 'accelerate_keys_dir', 'ACCELERATE_KEYS_DIR', '~/.fireball.keys') ACCELERATE_KEYS_DIR_PERMS = get_config(p, 'accelerate', 'accelerate_keys_dir_perms', 'ACCELERATE_KEYS_DIR_PERMS', '700') ACCELERATE_KEYS_FILE_PERMS = get_config(p, 'accelerate', 'accelerate_keys_file_perms', 'ACCELERATE_KEYS_FILE_PERMS', '600') +ACCELERATE_MULTI_KEY = get_config(p, 'accelerate', 'accelerate_multi_key', 'ACCELERATE_MULTI_KEY', False, boolean=True) PARAMIKO_PTY = get_config(p, 'paramiko_connection', 'pty', 'ANSIBLE_PARAMIKO_PTY', True, boolean=True) # characters included in auto-generated passwords diff --git a/lib/ansible/runner/connection_plugins/accelerate.py b/lib/ansible/runner/connection_plugins/accelerate.py index 60c1319262a..8cc29fab9dc 100644 --- a/lib/ansible/runner/connection_plugins/accelerate.py +++ b/lib/ansible/runner/connection_plugins/accelerate.py @@ -22,10 +22,10 @@ import socket import struct import time from ansible.callbacks import vvv, vvvv +from ansible.errors import AnsibleError, AnsibleFileNotFound from ansible.runner.connection_plugins.ssh import Connection as SSHConnection from ansible.runner.connection_plugins.paramiko_ssh import Connection as ParamikoConnection from ansible import utils -from ansible import errors from ansible import constants # the chunk size to read and send, assuming mtu 1500 and @@ -85,7 +85,9 @@ class Connection(object): utils.AES_KEYS = self.runner.aes_keys def _execute_accelerate_module(self): - args = "password=%s port=%s debug=%d ipv6=%s" % (base64.b64encode(self.key.__str__()), str(self.accport), int(utils.VERBOSITY), self.runner.accelerate_ipv6) + args = "password=%s port=%s minutes=%d debug=%d ipv6=%s" % (base64.b64encode(self.key.__str__()), str(self.accport), constants.ACCELERATE_TIMEOUT, int(utils.VERBOSITY), self.runner.accelerate_ipv6) + if constants.ACCELERATE_MULTI_KEY: + args += " multi_key=yes" inject = dict(password=self.key) if getattr(self.runner, 'accelerate_inventory_host', False): inject = utils.combine_vars(inject, self.runner.inventory.get_variables(self.runner.accelerate_inventory_host)) @@ -109,33 +111,38 @@ class Connection(object): while tries > 0: try: self.conn.connect((self.host,self.accport)) - if not self.validate_user(): - # the accelerated daemon was started with a - # different remote_user. The above command - # should have caused the accelerate daemon to - # shutdown, so we'll reconnect. - wrong_user = True break - except: - vvvv("failed, retrying...") + except socket.error: + vvvv("connection to %s failed, retrying..." % self.host) time.sleep(0.1) tries -= 1 if tries == 0: vvv("Could not connect via the accelerated connection, exceeded # of tries") - raise errors.AnsibleError("Failed to connect") + raise AnsibleError("FAILED") elif wrong_user: vvv("Restarting daemon with a different remote_user") - raise errors.AnsibleError("Wrong user") + raise AnsibleError("WRONG_USER") + self.conn.settimeout(constants.ACCELERATE_TIMEOUT) - except: + if not self.validate_user(): + # the accelerated daemon was started with a + # different remote_user. The above command + # should have caused the accelerate daemon to + # shutdown, so we'll reconnect. + wrong_user = True + + except AnsibleError, e: if allow_ssh: + if "WRONG_USER" in e: + vvv("Switching users, waiting for the daemon on %s to shutdown completely..." % self.host) + time.sleep(5) vvv("Falling back to ssh to startup accelerated mode") res = self._execute_accelerate_module() if not res.is_successful(): - raise errors.AnsibleError("Failed to launch the accelerated daemon on %s (reason: %s)" % (self.host,res.result.get('msg'))) + raise AnsibleError("Failed to launch the accelerated daemon on %s (reason: %s)" % (self.host,res.result.get('msg'))) return self.connect(allow_ssh=False) else: - raise errors.AnsibleError("Failed to connect to %s:%s" % (self.host,self.accport)) + raise AnsibleError("Failed to connect to %s:%s" % (self.host,self.accport)) self.is_connected = True return self @@ -163,11 +170,12 @@ class Connection(object): if not d: vvvv("%s: received nothing, bailing out" % self.host) return None + vvvv("%s: received %d bytes" % (self.host, len(d))) data += d vvvv("%s: received all of the data, returning" % self.host) return data except socket.timeout: - raise errors.AnsibleError("timed out while waiting to receive data") + raise AnsibleError("timed out while waiting to receive data") def validate_user(self): ''' @@ -176,6 +184,7 @@ class Connection(object): daemon to exit if they don't match ''' + vvvv("%s: sending request for validate_user" % self.host) data = dict( mode='validate_user', username=self.user, @@ -183,15 +192,16 @@ class Connection(object): data = utils.jsonify(data) data = utils.encrypt(self.key, data) if self.send_data(data): - raise errors.AnsibleError("Failed to send command to %s" % self.host) + raise AnsibleError("Failed to send command to %s" % self.host) + vvvv("%s: waiting for validate_user response" % self.host) while True: # we loop here while waiting for the response, because a # long running command may cause us to receive keepalive packets # ({"pong":"true"}) rather than the response we want. response = self.recv_data() if not response: - raise errors.AnsibleError("Failed to get a response from %s" % self.host) + raise AnsibleError("Failed to get a response from %s" % self.host) response = utils.decrypt(self.key, response) response = utils.parse_json(response) if "pong" in response: @@ -199,11 +209,11 @@ class Connection(object): vvvv("%s: received a keepalive packet" % self.host) continue else: - vvvv("%s: received the response" % self.host) + vvvv("%s: received the validate_user response: %s" % (self.host, response)) break if response.get('failed'): - raise errors.AnsibleError("Error while validating user: %s" % response.get("msg")) + return False else: return response.get('rc') == 0 @@ -211,10 +221,10 @@ class Connection(object): ''' run a command on the remote host ''' if su or su_user: - raise errors.AnsibleError("Internal Error: this module does not support running commands via su") + raise AnsibleError("Internal Error: this module does not support running commands via su") if in_data: - raise errors.AnsibleError("Internal Error: this module does not support optimized module pipelining") + raise AnsibleError("Internal Error: this module does not support optimized module pipelining") if executable == "": executable = constants.DEFAULT_EXECUTABLE @@ -233,7 +243,7 @@ class Connection(object): data = utils.jsonify(data) data = utils.encrypt(self.key, data) if self.send_data(data): - raise errors.AnsibleError("Failed to send command to %s" % self.host) + raise AnsibleError("Failed to send command to %s" % self.host) while True: # we loop here while waiting for the response, because a @@ -241,7 +251,7 @@ class Connection(object): # ({"pong":"true"}) rather than the response we want. response = self.recv_data() if not response: - raise errors.AnsibleError("Failed to get a response from %s" % self.host) + raise AnsibleError("Failed to get a response from %s" % self.host) response = utils.decrypt(self.key, response) response = utils.parse_json(response) if "pong" in response: @@ -260,7 +270,7 @@ class Connection(object): vvv("PUT %s TO %s" % (in_path, out_path), host=self.host) if not os.path.exists(in_path): - raise errors.AnsibleFileNotFound("file or module does not exist: %s" % in_path) + raise AnsibleFileNotFound("file or module does not exist: %s" % in_path) fd = file(in_path, 'rb') fstat = os.stat(in_path) @@ -279,27 +289,27 @@ class Connection(object): data = utils.encrypt(self.key, data) if self.send_data(data): - raise errors.AnsibleError("failed to send the file to %s" % self.host) + raise AnsibleError("failed to send the file to %s" % self.host) response = self.recv_data() if not response: - raise errors.AnsibleError("Failed to get a response from %s" % self.host) + raise AnsibleError("Failed to get a response from %s" % self.host) response = utils.decrypt(self.key, response) response = utils.parse_json(response) if response.get('failed',False): - raise errors.AnsibleError("failed to put the file in the requested location") + raise AnsibleError("failed to put the file in the requested location") finally: fd.close() vvvv("waiting for final response after PUT") response = self.recv_data() if not response: - raise errors.AnsibleError("Failed to get a response from %s" % self.host) + raise AnsibleError("Failed to get a response from %s" % self.host) response = utils.decrypt(self.key, response) response = utils.parse_json(response) if response.get('failed',False): - raise errors.AnsibleError("failed to put the file in the requested location") + raise AnsibleError("failed to put the file in the requested location") def fetch_file(self, in_path, out_path): ''' save a remote file to the specified path ''' @@ -309,7 +319,7 @@ class Connection(object): data = utils.jsonify(data) data = utils.encrypt(self.key, data) if self.send_data(data): - raise errors.AnsibleError("failed to initiate the file fetch with %s" % self.host) + raise AnsibleError("failed to initiate the file fetch with %s" % self.host) fh = open(out_path, "w") try: @@ -317,11 +327,11 @@ class Connection(object): while True: response = self.recv_data() if not response: - raise errors.AnsibleError("Failed to get a response from %s" % self.host) + raise AnsibleError("Failed to get a response from %s" % self.host) response = utils.decrypt(self.key, response) response = utils.parse_json(response) if response.get('failed', False): - raise errors.AnsibleError("Error during file fetch, aborting") + raise AnsibleError("Error during file fetch, aborting") out = base64.b64decode(response['data']) fh.write(out) bytes += len(out) @@ -330,7 +340,7 @@ class Connection(object): data = utils.jsonify(dict()) data = utils.encrypt(self.key, data) if self.send_data(data): - raise errors.AnsibleError("failed to send ack during file fetch") + raise AnsibleError("failed to send ack during file fetch") if response.get('last', False): break finally: diff --git a/library/utilities/accelerate b/library/utilities/accelerate index a61e54e374d..5a8c96c64a9 100644 --- a/library/utilities/accelerate +++ b/library/utilities/accelerate @@ -53,6 +53,14 @@ options: if this parameter is set to true. required: false default: false + multi_key: + description: + - When enabled, the daemon will open a local socket file which can be used by future daemon executions to + upload a new key to the already running daemon, so that multiple users can connect using different keys. + This access still requires an ssh connection as the uid for which the daemon is currently running. + required: false + default: no + version_added: "1.6" notes: - See the advanced playbooks chapter for more about using accelerated mode. requirements: [ "python-keyczar" ] @@ -71,6 +79,7 @@ EXAMPLES = ''' ''' import base64 +import errno import getpass import json import os @@ -88,10 +97,13 @@ import traceback import SocketServer from datetime import datetime -from threading import Thread +from threading import Thread, Lock + +# import module snippets +# we must import this here at the top so we can use get_module_path() +from ansible.module_utils.basic import * syslog.openlog('ansible-%s' % os.path.basename(__file__)) -PIDFILE = os.path.expanduser("~/.accelerate.pid") # the chunk size to read and send, assuming mtu 1500 and # leaving room for base64 (+33%) encoding and header (100 bytes) @@ -107,6 +119,9 @@ def log(msg, cap=0): if DEBUG_LEVEL >= cap: syslog.syslog(syslog.LOG_NOTICE|syslog.LOG_DAEMON, msg) +def v(msg): + log(msg, cap=1) + def vv(msg): log(msg, cap=2) @@ -116,16 +131,6 @@ def vvv(msg): def vvvv(msg): log(msg, cap=4) -if os.path.exists(PIDFILE): - try: - data = int(open(PIDFILE).read()) - try: - os.kill(data, signal.SIGKILL) - except OSError: - pass - except ValueError: - pass - os.unlink(PIDFILE) HAS_KEYCZAR = False try: @@ -134,10 +139,26 @@ try: except ImportError: pass +SOCKET_FILE = os.path.join(get_module_path(), '.ansible-accelerate', ".local.socket") + +def get_pid_location(module): + """ + Try to find a pid directory in the common locations, falling + back to the user's home directory if no others exist + """ + for dir in ['/var/run', '/var/lib/run', '/run', os.path.expanduser("~/")]: + try: + if os.path.isdir(dir) and os.access(dir, os.R_OK|os.W_OK): + return os.path.join(dir, '.accelerate.pid') + except: + pass + module.fail_json(msg="couldn't find any valid directory to use for the accelerate pid file") + + # NOTE: this shares a fair amount of code in common with async_wrapper, if async_wrapper were a new module we could move # this into utils.module_common and probably should anyway -def daemonize_self(module, password, port, minutes): +def daemonize_self(module, password, port, minutes, pid_file): # daemonizing code: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66012 try: pid = os.fork() @@ -158,11 +179,11 @@ def daemonize_self(module, password, port, minutes): try: pid = os.fork() if pid > 0: - log("daemon pid %s, writing %s" % (pid, PIDFILE)) - pid_file = open(PIDFILE, "w") + log("daemon pid %s, writing %s" % (pid, pid_file)) + pid_file = open(pid_file, "w") pid_file.write("%s" % pid) pid_file.close() - vvv("pidfile written") + vvv("pid file written") sys.exit(0) except OSError, e: log("fork #2 failed: %d (%s)" % (e.errno, e.strerror)) @@ -174,8 +195,85 @@ def daemonize_self(module, password, port, minutes): os.dup2(dev_null.fileno(), sys.stderr.fileno()) log("daemonizing successful") -class ThreadWithReturnValue(Thread): +class LocalSocketThread(Thread): + server = None + terminated = False + def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, Verbose=None): + self.server = kwargs.get('server') + Thread.__init__(self, group, target, name, args, kwargs, Verbose) + + def run(self): + try: + if os.path.exists(SOCKET_FILE): + os.remove(SOCKET_FILE) + else: + dir = os.path.dirname(SOCKET_FILE) + if os.path.exists(dir): + if not os.path.isdir(dir): + log("The socket file path (%s) exists, but is not a directory. No local connections will be available" % dir) + return + else: + # make sure the directory is accessible only to this + # user, as socket files derive their permissions from + # the directory that contains them + os.chmod(dir, 0700) + elif not os.path.exists(dir): + os.makedirs(dir, 0700) + except OSError: + pass + self.s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.s.bind(SOCKET_FILE) + self.s.listen(5) + while not self.terminated: + try: + conn, addr = self.s.accept() + vv("received local connection") + data = "" + while "\n" not in data: + data += conn.recv(2048) + try: + new_key = AesKey.Read(data.strip()) + found = False + for key in self.server.key_list: + try: + new_key.Decrypt(key.Encrypt("foo")) + found = True + break + except: + pass + if not found: + vv("adding new key to the key list") + self.server.key_list.append(new_key) + conn.sendall("OK\n") + else: + vv("key already exists in the key list, ignoring") + conn.sendall("EXISTS\n") + + # update the last event time so the server doesn't + # shutdown sooner than expected for new cliets + try: + self.server.last_event_lock.acquire() + self.server.last_event = datetime.now() + finally: + self.server.last_event_lock.release() + except Exception, e: + vv("key loaded locally was invalid, ignoring (%s)" % e) + conn.sendall("BADKEY\n") + finally: + try: + conn.close() + except: + pass + except: + pass + + def terminate(self): + self.terminated = True + self.s.shutdown(socket.SHUT_RDWR) + self.s.close() + +class ThreadWithReturnValue(Thread): def __init__(self, group=None, target=None, name=None, args=(), kwargs={}, Verbose=None): Thread.__init__(self, group, target, name, args, kwargs, Verbose) self._return = None @@ -190,24 +288,41 @@ class ThreadWithReturnValue(Thread): return self._return class ThreadedTCPServer(SocketServer.ThreadingTCPServer): - def __init__(self, server_address, RequestHandlerClass, module, password, timeout): + key_list = [] + last_event = datetime.now() + last_event_lock = Lock() + def __init__(self, server_address, RequestHandlerClass, module, password, timeout, use_ipv6=False): self.module = module - self.key = AesKey.Read(password) + self.key_list.append(AesKey.Read(password)) self.allow_reuse_address = True self.timeout = timeout + + if use_ipv6: + self.address_family = socket.AF_INET6 + + if self.module.params.get('multi_key', False): + vv("starting thread to handle local connections for multiple keys") + self.local_thread = LocalSocketThread(kwargs=dict(server=self)) + self.local_thread.start() + SocketServer.ThreadingTCPServer.__init__(self, server_address, RequestHandlerClass) -class ThreadedTCPV6Server(SocketServer.ThreadingTCPServer): - def __init__(self, server_address, RequestHandlerClass, module, password, timeout): - self.module = module - self.address_family = socket.AF_INET6 - self.key = AesKey.Read(password) - self.allow_reuse_address = True - self.timeout = timeout - SocketServer.ThreadingTCPServer.__init__(self, server_address, RequestHandlerClass) + def shutdown(self): + self.local_thread.terminate() + self.running = False + SocketServer.ThreadingTCPServer.shutdown(self) class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): + # the key to use for this connection + active_key = None + def send_data(self, data): + try: + self.server.last_event_lock.acquire() + self.server.last_event = datetime.now() + finally: + self.server.last_event_lock.release() + packed_len = struct.pack('!Q', len(data)) return self.request.sendall(packed_len + data) @@ -216,23 +331,40 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): data = "" vvvv("in recv_data(), waiting for the header") while len(data) < header_len: - d = self.request.recv(header_len - len(data)) - if not d: - vvv("received nothing, bailing out") + try: + d = self.request.recv(header_len - len(data)) + if not d: + vvv("received nothing, bailing out") + return None + data += d + except: + # probably got a connection reset + vvvv("exception received while waiting for recv(), returning None") return None - data += d vvvv("in recv_data(), got the header, unpacking") data_len = struct.unpack('!Q',data[:header_len])[0] data = data[header_len:] vvvv("data received so far (expecting %d): %d" % (data_len,len(data))) while len(data) < data_len: - d = self.request.recv(data_len - len(data)) - if not d: - vvv("received nothing, bailing out") + try: + d = self.request.recv(data_len - len(data)) + if not d: + vvv("received nothing, bailing out") + return None + data += d + vvvv("data received so far (expecting %d): %d" % (data_len,len(data))) + except: + # probably got a connection reset + vvvv("exception received while waiting for recv(), returning None") return None - data += d - vvvv("data received so far (expecting %d): %d" % (data_len,len(data))) vvvv("received all of the data, returning") + + try: + self.server.last_event_lock.acquire() + self.server.last_event = datetime.now() + finally: + self.server.last_event_lock.release() + return data def handle(self): @@ -243,18 +375,26 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): if not data: vvvv("received nothing back from recv_data(), breaking out") break - try: - vvvv("got data, decrypting") - data = self.server.key.Decrypt(data) - vvvv("decryption done") - except: - vv("bad decrypt, skipping...") - data2 = json.dumps(dict(rc=1)) - data2 = self.server.key.Encrypt(data2) - self.send_data(data2) - return + vvvv("got data, decrypting") + if not self.active_key: + for key in self.server.key_list: + try: + data = key.Decrypt(data) + self.active_key = key + break + except: + pass + else: + vv("bad decrypt, exiting the connection handler") + return + else: + try: + data = self.active_key.Decrypt(data) + except: + vv("bad decrypt, exiting the connection handler") + return - vvvv("loading json from the data") + vvvv("decryption done, loading json from the data") data = json.loads(data) mode = data['mode'] @@ -270,7 +410,7 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): last_pong = datetime.now() vvvv("command still running, sending keepalive packet") data2 = json.dumps(dict(pong=True)) - data2 = self.server.key.Encrypt(data2) + data2 = self.active_key.Encrypt(data2) self.send_data(data2) time.sleep(0.1) response = twrv._return @@ -286,8 +426,9 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): response = self.validate_user(data) vvvv("response result is %s" % str(response)) - data2 = json.dumps(response) - data2 = self.server.key.Encrypt(data2) + json_response = json.dumps(response) + vvvv("dumped json is %s" % json_response) + data2 = self.active_key.Encrypt(json_response) vvvv("sending the response back to the controller") self.send_data(data2) vvvv("done sending the response") @@ -299,9 +440,10 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): tb = traceback.format_exc() log("encountered an unhandled exception in the handle() function") log("error was:\n%s" % tb) - data2 = json.dumps(dict(rc=1, failed=True, msg="unhandled error in the handle() function")) - data2 = self.server.key.Encrypt(data2) - self.send_data(data2) + if self.active_key: + data2 = json.dumps(dict(rc=1, failed=True, msg="unhandled error in the handle() function")) + data2 = self.active_key.Encrypt(data2) + self.send_data(data2) def validate_user(self, data): if 'username' not in data: @@ -362,7 +504,7 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): last = True data = dict(data=base64.b64encode(data), last=last) data = json.dumps(data) - data = self.server.key.Encrypt(data) + data = self.active_key.Encrypt(data) if self.send_data(data): return dict(failed=True, stderr="failed to send data") @@ -371,7 +513,7 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): if not response: log("failed to get a response, aborting") return dict(failed=True, stderr="Failed to get a response from %s" % self.host) - response = self.server.key.Decrypt(response) + response = self.active_key.Decrypt(response) response = json.loads(response) if response.get('failed',False): @@ -394,7 +536,7 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): final_path = None if 'user' in data and data.get('user') != getpass.getuser(): - vv("the target user doesn't match this user, we'll move the file into place via sudo") + vvv("the target user doesn't match this user, we'll move the file into place via sudo") tmp_path = os.path.expanduser('~/.ansible/tmp/') if not os.path.exists(tmp_path): try: @@ -415,14 +557,14 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): bytes += len(out) out_fd.write(out) response = json.dumps(dict()) - response = self.server.key.Encrypt(response) + response = self.active_key.Encrypt(response) self.send_data(response) if data['last']: break data = self.recv_data() if not data: raise "" - data = self.server.key.Decrypt(data) + data = self.active_key.Decrypt(data) data = json.loads(data) except: out_fd.close() @@ -438,27 +580,45 @@ class ThreadedTCPRequestHandler(SocketServer.BaseRequestHandler): self.server.module.atomic_move(out_path, final_path) return dict() -def daemonize(module, password, port, timeout, minutes, ipv6): +def daemonize(module, password, port, timeout, minutes, use_ipv6, pid_file): try: - daemonize_self(module, password, port, minutes) + daemonize_self(module, password, port, minutes, pid_file) - def catcher(signum, _): - module.exit_json(msg='timer expired') + def timer_handler(signum, _): + try: + server.last_event_lock.acquire() + td = datetime.now() - server.last_event + # older python timedelta objects don't have total_seconds(), + # so we use the formula from the docs to calculate it + total_seconds = (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / 10**6 + if total_seconds >= minutes * 60: + log("server has been idle longer than the timeout, shutting down") + server.running = False + server.shutdown() + else: + # reschedule the check + vvvv("daemon idle for %d seconds (timeout=%d)" % (total_seconds,minutes*60)) + signal.alarm(30) + except: + pass + finally: + server.last_event_lock.release() - signal.signal(signal.SIGALRM, catcher) - signal.setitimer(signal.ITIMER_REAL, 60 * minutes) + signal.signal(signal.SIGALRM, timer_handler) + signal.alarm(30) tries = 5 while tries > 0: try: - if ipv6: - server = ThreadedTCPV6Server(("::", port), ThreadedTCPRequestHandler, module, password, timeout) + if use_ipv6: + address = ("::", port) else: - server = ThreadedTCPServer(("0.0.0.0", port), ThreadedTCPRequestHandler, module, password, timeout) + address = ("0.0.0.0", port) + server = ThreadedTCPServer(address, ThreadedTCPRequestHandler, module, password, timeout, use_ipv6=use_ipv6) server.allow_reuse_address = True break - except: - vv("Failed to create the TCP server (tries left = %d)" % tries) + except Exception, e: + vv("Failed to create the TCP server (tries left = %d) (error: %s) " % (tries,e)) tries -= 1 time.sleep(0.2) @@ -466,8 +626,20 @@ def daemonize(module, password, port, timeout, minutes, ipv6): vv("Maximum number of attempts to create the TCP server reached, bailing out") raise Exception("max # of attempts to serve reached") - vv("serving!") - server.serve_forever(poll_interval=0.1) + # run the server in a separate thread to make signal handling work + server_thread = Thread(target=server.serve_forever, kwargs=dict(poll_interval=0.1)) + server_thread.start() + server.running = True + + v("serving!") + while server.running: + time.sleep(1) + + # wait for the thread to exit fully + server_thread.join() + + v("server thread terminated, exiting!") + sys.exit(0) except Exception, e: tb = traceback.format_exc() log("exception caught, exiting accelerated mode: %s\n%s" % (e, tb)) @@ -479,6 +651,7 @@ def main(): argument_spec = dict( port=dict(required=False, default=5099), ipv6=dict(required=False, default=False, type='bool'), + multi_key=dict(required=False, default=False, type='bool'), timeout=dict(required=False, default=300), password=dict(required=True), minutes=dict(required=False, default=30), @@ -493,14 +666,62 @@ def main(): minutes = int(module.params['minutes']) debug = int(module.params['debug']) ipv6 = module.params['ipv6'] + multi_key = module.params['multi_key'] if not HAS_KEYCZAR: module.fail_json(msg="keyczar is not installed (on the remote side)") DEBUG_LEVEL=debug + pid_file = get_pid_location(module) - daemonize(module, password, port, timeout, minutes, ipv6) + daemon_pid = None + daemon_running = False + if os.path.exists(pid_file): + try: + daemon_pid = int(open(pid_file).read()) + try: + # sending signal 0 doesn't do anything to the + # process, other than tell the calling program + # whether other signals can be sent + os.kill(daemon_pid, 0) + except OSError, e: + if e.errno == errno.EPERM: + # no permissions means the pid is probably + # running, but as a different user, so fail + module.fail_json(msg="the accelerate daemon appears to be running as a different user that this user cannot access (pid=%d)" % daemon_pid) + else: + daemon_running = True + except ValueError: + # invalid pid file, unlink it - otherwise we don't care + try: + os.unlink(pid_file) + except: + pass + + if daemon_running and multi_key: + # try to connect to the file socket for the daemon if it exists + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + s.connect(SOCKET_FILE) + s.sendall(password + '\n') + data = "" + while '\n' not in data: + data += s.recv(2048) + res = data.strip() + except: + module.fail_json(msg="failed to connect to the local socket file") + finally: + try: + s.close() + except: + pass + + if res in ("OK", "EXISTS"): + module.exit_json(msg="transferred new key to the existing daemon") + else: + module.fail_json(msg="could not transfer new key: %s" % data.strip()) + else: + # try to start up the daemon + daemonize(module, password, port, timeout, minutes, ipv6, pid_file) -# import module snippets -from ansible.module_utils.basic import * main()