From 8ef78b1cf8798611ff6bba5a8962e978a405670f Mon Sep 17 00:00:00 2001 From: James Cammarata Date: Fri, 2 Oct 2015 00:35:22 -0400 Subject: [PATCH] Fixing accelerated connection plugin --- lib/ansible/executor/task_executor.py | 38 ++- lib/ansible/playbook/play_context.py | 10 + lib/ansible/plugins/connection/accelerate.py | 290 +++++++------------ lib/ansible/utils/encrypt.py | 77 +++++ 4 files changed, 229 insertions(+), 186 deletions(-) diff --git a/lib/ansible/executor/task_executor.py b/lib/ansible/executor/task_executor.py index 0e34938ed6b..cc997a5e719 100644 --- a/lib/ansible/executor/task_executor.py +++ b/lib/ansible/executor/task_executor.py @@ -19,6 +19,7 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +import base64 import json import pipes import subprocess @@ -33,6 +34,7 @@ from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVar from ansible.playbook.conditional import Conditional from ansible.playbook.task import Task from ansible.template import Templar +from ansible.utils.encrypt import key_for_hostname from ansible.utils.listify import listify_lookup_plugin_terms from ansible.utils.unicode import to_unicode from ansible.vars.unsafe_proxy import UnsafeProxy @@ -309,7 +311,7 @@ class TaskExecutor: return dict(include=include_file, include_variables=include_variables) # get the connection and the handler for this execution - self._connection = self._get_connection(variables) + self._connection = self._get_connection(variables=variables, templar=templar) self._connection.set_host_overrides(host=self._host) self._handler = self._get_action_handler(connection=self._connection, templar=templar) @@ -466,7 +468,7 @@ class TaskExecutor: else: return async_result - def _get_connection(self, variables): + def _get_connection(self, variables, templar): ''' Reads the connection property for the host, and returns the correct connection object from the list of connection plugins @@ -513,6 +515,38 @@ class TaskExecutor: if not connection: raise AnsibleError("the connection plugin '%s' was not found" % conn_type) + if self._play_context.accelerate: + # launch the accelerated daemon here + ssh_connection = connection + handler = self._shared_loader_obj.action_loader.get( + 'normal', + task=self._task, + connection=ssh_connection, + play_context=self._play_context, + loader=self._loader, + templar=templar, + shared_loader_obj=self._shared_loader_obj, + ) + + key = key_for_hostname(self._play_context.remote_addr) + accelerate_args = dict( + password=base64.b64encode(key.__str__()), + port=self._play_context.accelerate_port, + minutes=C.ACCELERATE_DAEMON_TIMEOUT, + ipv6=self._play_context.accelerate_ipv6, + debug=self._play_context.verbosity, + ) + + connection = self._shared_loader_obj.connection_loader.get('accelerate', self._play_context, self._new_stdin) + if not connection: + raise AnsibleError("the connection plugin '%s' was not found" % conn_type) + + try: + connection._connect() + except AnsibleConnectionFailure: + res = handler._execute_module(module_name='accelerate', module_args=accelerate_args, task_vars=variables, delete_remote_tmp=False) + connection._connect() + return connection def _get_action_handler(self, connection, templar): diff --git a/lib/ansible/playbook/play_context.py b/lib/ansible/playbook/play_context.py index f85e442822e..dfccf7345b4 100644 --- a/lib/ansible/playbook/play_context.py +++ b/lib/ansible/playbook/play_context.py @@ -56,6 +56,7 @@ MAGIC_VARIABLE_MAPPING = dict( remote_addr = ('ansible_ssh_host', 'ansible_host'), remote_user = ('ansible_ssh_user', 'ansible_user'), port = ('ansible_ssh_port', 'ansible_port'), + accelerate_port = ('ansible_accelerate_port',), password = ('ansible_ssh_pass', 'ansible_password'), private_key_file = ('ansible_ssh_private_key_file', 'ansible_private_key_file'), pipelining = ('ansible_ssh_pipelining', 'ansible_pipelining'), @@ -142,6 +143,9 @@ class PlayContext(Base): _ssh_extra_args = FieldAttribute(isa='string') _connection_lockfd= FieldAttribute(isa='int') _pipelining = FieldAttribute(isa='bool', default=C.ANSIBLE_SSH_PIPELINING) + _accelerate = FieldAttribute(isa='bool', default=False) + _accelerate_ipv6 = FieldAttribute(isa='bool', default=False, always_post_validate=True) + _accelerate_port = FieldAttribute(isa='int', default=C.ACCELERATE_PORT, always_post_validate=True) # privilege escalation fields _become = FieldAttribute(isa='bool') @@ -199,6 +203,12 @@ class PlayContext(Base): the play class. ''' + # special handling for accelerated mode, as it is set in a separate + # play option from the connection parameter + self.accelerate = play.accelerate + self.accelerate_ipv6 = play.accelerate_ipv6 + self.accelerate_port = play.accelerate_port + if play.connection: self.connection = play.connection diff --git a/lib/ansible/plugins/connection/accelerate.py b/lib/ansible/plugins/connection/accelerate.py index dfff616703c..4b4b068da4f 100644 --- a/lib/ansible/plugins/connection/accelerate.py +++ b/lib/ansible/plugins/connection/accelerate.py @@ -18,19 +18,20 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +import base64 import json import os -import base64 import socket import struct import time -from ansible.callbacks import vvv, vvvv -from ansible.errors import AnsibleError, AnsibleFileNotFound -from . import ConnectionBase -from .ssh import Connection as SSHConnection -from .paramiko_ssh import Connection as ParamikoConnection -from ansible import utils -from ansible import constants + +from ansible import constants as C +from ansible.errors import AnsibleError, AnsibleFileNotFound, AnsibleConnectionFailure +from ansible.parsing.utils.jsonify import jsonify +from ansible.plugins.connection import ConnectionBase +from ansible.plugins.connection.ssh import Connection as SSHConnection +from ansible.plugins.connection.paramiko_ssh import Connection as ParamikoConnection +from ansible.utils.encrypt import key_for_hostname, keyczar_encrypt, keyczar_decrypt # the chunk size to read and send, assuming mtu 1500 and # leaving room for base64 (+33%) encoding and header (8 bytes) @@ -42,127 +43,50 @@ CHUNK_SIZE=1044*20 class Connection(ConnectionBase): ''' raw socket accelerated connection ''' - def __init__(self, runner, host, port, user, password, private_key_file, *args, **kwargs): + transport = 'accelerate' + has_pipelining = False + become_methods = frozenset(C.BECOME_METHODS).difference(['runas']) + + def __init__(self, *args, **kwargs): + + super(Connection, self).__init__(*args, **kwargs) - self.runner = runner - self.host = host - self.context = None self.conn = None - self.user = user - self.key = utils.key_for_hostname(host) - self.port = port[0] - self.accport = port[1] - self.is_connected = False - self.has_pipelining = False - self.become_methods_supported=['sudo'] + self.key = key_for_hostname(self._play_context.remote_addr) - if not self.port: - self.port = constants.DEFAULT_REMOTE_PORT - elif not isinstance(self.port, int): - self.port = int(self.port) - - if not self.accport: - self.accport = constants.ACCELERATE_PORT - elif not isinstance(self.accport, int): - self.accport = int(self.accport) - - if self.runner.original_transport == "paramiko": - self.ssh = ParamikoConnection( - runner=self.runner, - host=self.host, - port=self.port, - user=self.user, - password=password, - private_key_file=private_key_file - ) - else: - self.ssh = SSHConnection( - runner=self.runner, - host=self.host, - port=self.port, - user=self.user, - password=password, - private_key_file=private_key_file - ) - - if not getattr(self.ssh, 'shell', None): - self.ssh.shell = utils.plugins.shell_loader.get('sh') - - # attempt to work around shared-memory funness - if getattr(self.runner, 'aes_keys', None): - utils.AES_KEYS = self.runner.aes_keys - - @property - def transport(self): - """String used to identify this Connection class from other classes""" - return 'accelerate' - - def _execute_accelerate_module(self): - args = "password=%s port=%s minutes=%d debug=%d ipv6=%s" % ( - base64.b64encode(self.key.__str__()), - str(self.accport), - constants.ACCELERATE_DAEMON_TIMEOUT, - int(utils.VERBOSITY), - self.runner.accelerate_ipv6, - ) - if constants.ACCELERATE_MULTI_KEY: - args += " multi_key=yes" - inject = dict(password=self.key) - if getattr(self.runner, 'accelerate_inventory_host', False): - inject = utils.combine_vars(inject, self.runner.inventory.get_variables(self.runner.accelerate_inventory_host)) - else: - inject = utils.combine_vars(inject, self.runner.inventory.get_variables(self.host)) - vvvv("attempting to start up the accelerate daemon...") - self.ssh.connect() - tmp_path = self.runner._make_tmp_path(self.ssh) - return self.runner._execute_module(self.ssh, tmp_path, 'accelerate', args, inject=inject) - - def connect(self, allow_ssh=True): + def _connect(self): ''' activates the connection object ''' - try: - if not self.is_connected: - wrong_user = False - tries = 3 - self.conn = socket.socket() - self.conn.settimeout(constants.ACCELERATE_CONNECT_TIMEOUT) - vvvv("attempting connection to %s via the accelerated port %d" % (self.host,self.accport)) - while tries > 0: - try: - self.conn.connect((self.host,self.accport)) - break - except socket.error: - vvvv("connection to %s failed, retrying..." % self.host) - time.sleep(0.1) - tries -= 1 - if tries == 0: - vvv("Could not connect via the accelerated connection, exceeded # of tries") - raise AnsibleError("FAILED") - elif wrong_user: - vvv("Restarting daemon with a different remote_user") - raise AnsibleError("WRONG_USER") + if not self._connected: + wrong_user = False + tries = 3 + self.conn = socket.socket() + self.conn.settimeout(C.ACCELERATE_CONNECT_TIMEOUT) + self._display.vvvv("attempting connection to %s via the accelerated port %d" % (self._play_context.remote_addr,self._play_context.accelerate_port)) + while tries > 0: + try: + self.conn.connect((self._play_context.remote_addr,self._play_context.accelerate_port)) + break + except socket.error: + self._display.vvvv("connection to %s failed, retrying..." % self._play_context.remote_addr) + time.sleep(0.1) + tries -= 1 + if tries == 0: + self._display.vvv("Could not connect via the accelerated connection, exceeded # of tries") + raise AnsibleConnectionFailure("Failed to connect to %s on the accelerated port %s" % (self._play_context.remote_addr, self._play_context.accelerate_port)) + elif wrong_user: + self._display.vvv("Restarting daemon with a different remote_user") + raise AnsibleError("The accelerated daemon was started on the remote with a different user") - self.conn.settimeout(constants.ACCELERATE_TIMEOUT) - if not self.validate_user(): - # the accelerated daemon was started with a - # different remote_user. The above command - # should have caused the accelerate daemon to - # shutdown, so we'll reconnect. - wrong_user = True + self.conn.settimeout(C.ACCELERATE_TIMEOUT) + if not self.validate_user(): + # the accelerated daemon was started with a + # different remote_user. The above command + # should have caused the accelerate daemon to + # shutdown, so we'll reconnect. + wrong_user = True - except AnsibleError as e: - if allow_ssh: - if "WRONG_USER" in e: - vvv("Switching users, waiting for the daemon on %s to shutdown completely..." % self.host) - time.sleep(5) - vvv("Falling back to ssh to startup accelerated mode") - res = self._execute_accelerate_module() - if not res.is_successful(): - raise AnsibleError("Failed to launch the accelerated daemon on %s (reason: %s)" % (self.host,res.result.get('msg'))) - return self.connect(allow_ssh=False) - else: - raise AnsibleError("Failed to connect to %s:%s" % (self.host,self.accport)) - self.is_connected = True + self._connected = True return self def send_data(self, data): @@ -173,25 +97,25 @@ class Connection(ConnectionBase): header_len = 8 # size of a packed unsigned long long data = b"" try: - vvvv("%s: in recv_data(), waiting for the header" % self.host) + self._display.vvvv("%s: in recv_data(), waiting for the header" % self._play_context.remote_addr) while len(data) < header_len: d = self.conn.recv(header_len - len(data)) if not d: - vvvv("%s: received nothing, bailing out" % self.host) + self._display.vvvv("%s: received nothing, bailing out" % self._play_context.remote_addr) return None data += d - vvvv("%s: got the header, unpacking" % self.host) + self._display.vvvv("%s: got the header, unpacking" % self._play_context.remote_addr) data_len = struct.unpack('!Q',data[:header_len])[0] data = data[header_len:] - vvvv("%s: data received so far (expecting %d): %d" % (self.host,data_len,len(data))) + self._display.vvvv("%s: data received so far (expecting %d): %d" % (self._play_context.remote_addr,data_len,len(data))) while len(data) < data_len: d = self.conn.recv(data_len - len(data)) if not d: - vvvv("%s: received nothing, bailing out" % self.host) + self._display.vvvv("%s: received nothing, bailing out" % self._play_context.remote_addr) return None - vvvv("%s: received %d bytes" % (self.host, len(d))) + self._display.vvvv("%s: received %d bytes" % (self._play_context.remote_addr, len(d))) data += d - vvvv("%s: received all of the data, returning" % self.host) + self._display.vvvv("%s: received all of the data, returning" % self._play_context.remote_addr) return data except socket.timeout: raise AnsibleError("timed out while waiting to receive data") @@ -203,32 +127,32 @@ class Connection(ConnectionBase): daemon to exit if they don't match ''' - vvvv("%s: sending request for validate_user" % self.host) + self._display.vvvv("%s: sending request for validate_user" % self._play_context.remote_addr) data = dict( mode='validate_user', - username=self.user, + username=self._play_context.remote_user, ) - data = utils.jsonify(data) - data = utils.encrypt(self.key, data) + data = jsonify(data) + data = keyczar_encrypt(self.key, data) if self.send_data(data): - raise AnsibleError("Failed to send command to %s" % self.host) + raise AnsibleError("Failed to send command to %s" % self._play_context.remote_addr) - vvvv("%s: waiting for validate_user response" % self.host) + self._display.vvvv("%s: waiting for validate_user response" % self._play_context.remote_addr) while True: # we loop here while waiting for the response, because a # long running command may cause us to receive keepalive packets # ({"pong":"true"}) rather than the response we want. response = self.recv_data() if not response: - raise AnsibleError("Failed to get a response from %s" % self.host) - response = utils.decrypt(self.key, response) - response = utils.parse_json(response) + raise AnsibleError("Failed to get a response from %s" % self._play_context.remote_addr) + response = keyczar_decrypt(self.key, response) + response = json.loads(response) if "pong" in response: # it's a keepalive, go back to waiting - vvvv("%s: received a keepalive packet" % self.host) + self._display.vvvv("%s: received a keepalive packet" % self._play_context.remote_addr) continue else: - vvvv("%s: received the validate_user response: %s" % (self.host, response)) + self._display.vvvv("%s: received the validate_user response: %s" % (self._play_context.remote_addr, response)) break if response.get('failed'): @@ -236,32 +160,30 @@ class Connection(ConnectionBase): else: return response.get('rc') == 0 - def exec_command(self, cmd, become_user=None, sudoable=False, executable='/bin/sh', in_data=None): + def exec_command(self, cmd, in_data=None, sudoable=True): + ''' run a command on the remote host ''' - if sudoable and self.runner.become and self.runner.become_method not in self.become_methods_supported: - raise errors.AnsibleError("Internal Error: this module does not support running commands via %s" % self.runner.become_method) + super(Connection, self).exec_command(cmd, in_data=in_data, sudoable=sudoable) + + # FIXME: + #if sudoable and self..become and self.runner.become_method not in self.become_methods_supported: + # raise AnsibleError("Internal Error: this module does not support running commands via %s" % self.runner.become_method) if in_data: raise AnsibleError("Internal Error: this module does not support optimized module pipelining") - if executable == "": - executable = constants.DEFAULT_EXECUTABLE - - if self.runner.become and sudoable: - cmd, prompt, success_key = utils.make_become_cmd(cmd, become_user, executable, self.runner.become_method, '', self.runner.become_exe) - - vvv("EXEC COMMAND %s" % cmd) + self._display.vvv("EXEC COMMAND %s" % cmd) data = dict( mode='command', cmd=cmd, - executable=executable, + executable=C.DEFAULT_EXECUTABLE, ) - data = utils.jsonify(data) - data = utils.encrypt(self.key, data) + data = jsonify(data) + data = keyczar_encrypt(self.key, data) if self.send_data(data): - raise AnsibleError("Failed to send command to %s" % self.host) + raise AnsibleError("Failed to send command to %s" % self._play_context.remote_addr) while True: # we loop here while waiting for the response, because a @@ -269,15 +191,15 @@ class Connection(ConnectionBase): # ({"pong":"true"}) rather than the response we want. response = self.recv_data() if not response: - raise AnsibleError("Failed to get a response from %s" % self.host) - response = utils.decrypt(self.key, response) - response = utils.parse_json(response) + raise AnsibleError("Failed to get a response from %s" % self._play_context.remote_addr) + response = keyczar_decrypt(self.key, response) + response = json.loads(response) if "pong" in response: # it's a keepalive, go back to waiting - vvvv("%s: received a keepalive packet" % self.host) + self._display.vvvv("%s: received a keepalive packet" % self._play_context.remote_addr) continue else: - vvvv("%s: received the response" % self.host) + self._display.vvvv("%s: received the response" % self._play_context.remote_addr) break return (response.get('rc', None), response.get('stdout', ''), response.get('stderr', '')) @@ -285,7 +207,7 @@ class Connection(ConnectionBase): def put_file(self, in_path, out_path): ''' transfer a file from local to remote ''' - vvv("PUT %s TO %s" % (in_path, out_path), host=self.host) + self._display.vvv("PUT %s TO %s" % (in_path, out_path), host=self._play_context.remote_addr) if not os.path.exists(in_path): raise AnsibleFileNotFound("file or module does not exist: %s" % in_path) @@ -293,51 +215,51 @@ class Connection(ConnectionBase): fd = file(in_path, 'rb') fstat = os.stat(in_path) try: - vvv("PUT file is %d bytes" % fstat.st_size) + self._display.vvv("PUT file is %d bytes" % fstat.st_size) last = False while fd.tell() <= fstat.st_size and not last: - vvvv("file position currently %ld, file size is %ld" % (fd.tell(), fstat.st_size)) + self._display.vvvv("file position currently %ld, file size is %ld" % (fd.tell(), fstat.st_size)) data = fd.read(CHUNK_SIZE) if fd.tell() >= fstat.st_size: last = True data = dict(mode='put', data=base64.b64encode(data), out_path=out_path, last=last) - if self.runner.become: - data['user'] = self.runner.become_user - data = utils.jsonify(data) - data = utils.encrypt(self.key, data) + if self._play_context.become: + data['user'] = self._play_context.become_user + data = jsonify(data) + data = keyczar_encrypt(self.key, data) if self.send_data(data): - raise AnsibleError("failed to send the file to %s" % self.host) + raise AnsibleError("failed to send the file to %s" % self._play_context.remote_addr) response = self.recv_data() if not response: - raise AnsibleError("Failed to get a response from %s" % self.host) - response = utils.decrypt(self.key, response) - response = utils.parse_json(response) + raise AnsibleError("Failed to get a response from %s" % self._play_context.remote_addr) + response = keyczar_decrypt(self.key, response) + response = json.loads(response) if response.get('failed',False): raise AnsibleError("failed to put the file in the requested location") finally: fd.close() - vvvv("waiting for final response after PUT") + self._display.vvvv("waiting for final response after PUT") response = self.recv_data() if not response: - raise AnsibleError("Failed to get a response from %s" % self.host) - response = utils.decrypt(self.key, response) - response = utils.parse_json(response) + raise AnsibleError("Failed to get a response from %s" % self._play_context.remote_addr) + response = keyczar_decrypt(self.key, response) + response = json.loads(response) if response.get('failed',False): raise AnsibleError("failed to put the file in the requested location") def fetch_file(self, in_path, out_path): ''' save a remote file to the specified path ''' - vvv("FETCH %s TO %s" % (in_path, out_path), host=self.host) + self._display.vvv("FETCH %s TO %s" % (in_path, out_path), host=self._play_context.remote_addr) data = dict(mode='fetch', in_path=in_path) - data = utils.jsonify(data) - data = utils.encrypt(self.key, data) + data = jsonify(data) + data = keyczar_encrypt(self.key, data) if self.send_data(data): - raise AnsibleError("failed to initiate the file fetch with %s" % self.host) + raise AnsibleError("failed to initiate the file fetch with %s" % self._play_context.remote_addr) fh = open(out_path, "w") try: @@ -345,9 +267,9 @@ class Connection(ConnectionBase): while True: response = self.recv_data() if not response: - raise AnsibleError("Failed to get a response from %s" % self.host) - response = utils.decrypt(self.key, response) - response = utils.parse_json(response) + raise AnsibleError("Failed to get a response from %s" % self._play_context.remote_addr) + response = keyczar_decrypt(self.key, response) + response = json.loads(response) if response.get('failed', False): raise AnsibleError("Error during file fetch, aborting") out = base64.b64decode(response['data']) @@ -355,8 +277,8 @@ class Connection(ConnectionBase): bytes += len(out) # send an empty response back to signify we # received the last chunk without errors - data = utils.jsonify(dict()) - data = utils.encrypt(self.key, data) + data = jsonify(dict()) + data = keyczar_encrypt(self.key, data) if self.send_data(data): raise AnsibleError("failed to send ack during file fetch") if response.get('last', False): @@ -367,7 +289,7 @@ class Connection(ConnectionBase): # point in the future or we may just have the put/fetch # operations not send back a final response at all response = self.recv_data() - vvv("FETCH wrote %d bytes to %s" % (bytes, out_path)) + self._display.vvv("FETCH wrote %d bytes to %s" % (bytes, out_path)) fh.close() def close(self): diff --git a/lib/ansible/utils/encrypt.py b/lib/ansible/utils/encrypt.py index 5138dbef705..da006fb773a 100644 --- a/lib/ansible/utils/encrypt.py +++ b/lib/ansible/utils/encrypt.py @@ -18,6 +18,11 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +import os +import stat +import time +import warnings + PASSLIB_AVAILABLE = False try: import passlib.hash @@ -25,6 +30,34 @@ try: except: pass +KEYCZAR_AVAILABLE=False +try: + try: + # some versions of pycrypto may not have this? + from Crypto.pct_warnings import PowmInsecureWarning + except ImportError: + PowmInsecureWarning = RuntimeWarning + + with warnings.catch_warnings(record=True) as warning_handler: + warnings.simplefilter("error", PowmInsecureWarning) + try: + import keyczar.errors as key_errors + from keyczar.keys import AesKey + except PowmInsecureWarning: + system_warning( + "The version of gmp you have installed has a known issue regarding " + \ + "timing vulnerabilities when used with pycrypto. " + \ + "If possible, you should update it (i.e. yum update gmp)." + ) + warnings.resetwarnings() + warnings.simplefilter("ignore") + import keyczar.errors as key_errors + from keyczar.keys import AesKey + KEYCZAR_AVAILABLE=True +except ImportError: + pass + +from ansible import constants as C from ansible.errors import AnsibleError __all__ = ['do_encrypt'] @@ -47,3 +80,47 @@ def do_encrypt(result, encrypt, salt_size=None, salt=None): return result +def key_for_hostname(hostname): + # fireball mode is an implementation of ansible firing up zeromq via SSH + # to use no persistent daemons or key management + + if not KEYCZAR_AVAILABLE: + raise AnsibleError("python-keyczar must be installed on the control machine to use accelerated modes") + + key_path = os.path.expanduser(C.ACCELERATE_KEYS_DIR) + if not os.path.exists(key_path): + os.makedirs(key_path, mode=0700) + os.chmod(key_path, int(C.ACCELERATE_KEYS_DIR_PERMS, 8)) + elif not os.path.isdir(key_path): + raise AnsibleError('ACCELERATE_KEYS_DIR is not a directory.') + + if stat.S_IMODE(os.stat(key_path).st_mode) != int(C.ACCELERATE_KEYS_DIR_PERMS, 8): + raise AnsibleError('Incorrect permissions on the private key directory. Use `chmod 0%o %s` to correct this issue, and make sure any of the keys files contained within that directory are set to 0%o' % (int(C.ACCELERATE_KEYS_DIR_PERMS, 8), C.ACCELERATE_KEYS_DIR, int(C.ACCELERATE_KEYS_FILE_PERMS, 8))) + + key_path = os.path.join(key_path, hostname) + + # use new AES keys every 2 hours, which means fireball must not allow running for longer either + if not os.path.exists(key_path) or (time.time() - os.path.getmtime(key_path) > 60*60*2): + key = AesKey.Generate(size=256) + fd = os.open(key_path, os.O_WRONLY | os.O_CREAT, int(C.ACCELERATE_KEYS_FILE_PERMS, 8)) + fh = os.fdopen(fd, 'w') + fh.write(str(key)) + fh.close() + return key + else: + if stat.S_IMODE(os.stat(key_path).st_mode) != int(C.ACCELERATE_KEYS_FILE_PERMS, 8): + raise AnsibleError('Incorrect permissions on the key file for this host. Use `chmod 0%o %s` to correct this issue.' % (int(C.ACCELERATE_KEYS_FILE_PERMS, 8), key_path)) + fh = open(key_path) + key = AesKey.Read(fh.read()) + fh.close() + return key + +def keyczar_encrypt(key, msg): + return key.Encrypt(msg.encode('utf-8')) + +def keyczar_decrypt(key, msg): + try: + return key.Decrypt(msg) + except key_errors.InvalidSignatureError: + raise AnsibleError("decryption failed") +