Merge pull request #13992 from electrofelix/accelerate-race

Fix race in accelerate connection plugin
This commit is contained in:
Brian Coca 2016-02-11 22:00:38 -05:00
commit d99955596e
2 changed files with 44 additions and 16 deletions

View file

@ -21,6 +21,7 @@ import base64
import socket
import struct
import time
import threading
from ansible.callbacks import vvv, vvvv
from ansible.errors import AnsibleError, AnsibleFileNotFound
from ansible.runner.connection_plugins.ssh import Connection as SSHConnection
@ -35,6 +36,8 @@ from ansible import constants
# multiple of the value to speed up file reads.
CHUNK_SIZE=1044*20
_LOCK = threading.Lock()
class Connection(object):
''' raw socket accelerated connection '''
@ -111,6 +114,15 @@ class Connection(object):
def connect(self, allow_ssh=True):
''' activates the connection object '''
# ensure only one fork tries to setup the connection, in case the
# first task for multiple hosts is delegated to the same host.
if not self.is_connected:
with(_LOCK):
return self._connect(allow_ssh)
return self
def _connect(self, allow_ssh=True):
try:
if not self.is_connected:
wrong_user = False
@ -150,7 +162,7 @@ class Connection(object):
res = self._execute_accelerate_module()
if not res.is_successful():
raise AnsibleError("Failed to launch the accelerated daemon on %s (reason: %s)" % (self.host,res.result.get('msg')))
return self.connect(allow_ssh=False)
return self._connect(allow_ssh=False)
else:
raise AnsibleError("Failed to connect to %s:%s" % (self.host,self.accport))
self.is_connected = True

View file

@ -49,6 +49,8 @@ import traceback
import getpass
import subprocess
import contextlib
import threading
import tempfile
from vault import VaultLib
@ -62,6 +64,7 @@ LOOKUP_REGEX = re.compile(r'lookup\s*\(')
PRINT_CODE_REGEX = re.compile(r'(?:{[{%]|[%}]})')
CODE_REGEX = re.compile(r'(?:{%|%})')
_LOCK = threading.Lock()
try:
# simplejson can be much faster if it's available
@ -127,8 +130,15 @@ def key_for_hostname(hostname):
key_path = os.path.expanduser(C.ACCELERATE_KEYS_DIR)
if not os.path.exists(key_path):
os.makedirs(key_path, mode=0700)
os.chmod(key_path, int(C.ACCELERATE_KEYS_DIR_PERMS, 8))
# avoid race with multiple forks trying to create paths on host
# but limit when locking is needed to creation only
with(_LOCK):
if not os.path.exists(key_path):
# use a temp directory and rename to ensure the directory
# searched for only appears after permissions applied.
tmp_dir = tempfile.mkdtemp(dir=os.path.dirname(key_path))
os.chmod(tmp_dir, int(C.ACCELERATE_KEYS_DIR_PERMS, 8))
os.rename(tmp_dir, key_path)
elif not os.path.isdir(key_path):
raise errors.AnsibleError('ACCELERATE_KEYS_DIR is not a directory.')
@ -139,19 +149,25 @@ def key_for_hostname(hostname):
# use new AES keys every 2 hours, which means fireball must not allow running for longer either
if not os.path.exists(key_path) or (time.time() - os.path.getmtime(key_path) > 60*60*2):
key = AesKey.Generate(size=256)
fd = os.open(key_path, os.O_WRONLY | os.O_CREAT, int(C.ACCELERATE_KEYS_FILE_PERMS, 8))
fh = os.fdopen(fd, 'w')
# avoid race with multiple forks trying to create key
# but limit when locking is needed to creation only
with(_LOCK):
if not os.path.exists(key_path) or (time.time() - os.path.getmtime(key_path) > 60*60*2):
key = AesKey.Generate()
# use temp file to ensure file only appears once it has
# desired contents and permissions
with tempfile.NamedTemporaryFile(mode='w', dir=os.path.dirname(key_path), delete=False) as fh:
tmp_key_path = fh.name
fh.write(str(key))
fh.close()
os.chmod(tmp_key_path, int(C.ACCELERATE_KEYS_FILE_PERMS, 8))
os.rename(tmp_key_path, key_path)
return key
else:
if stat.S_IMODE(os.stat(key_path).st_mode) != int(C.ACCELERATE_KEYS_FILE_PERMS, 8):
raise errors.AnsibleError('Incorrect permissions on the key file for this host. Use `chmod 0%o %s` to correct this issue.' % (int(C.ACCELERATE_KEYS_FILE_PERMS, 8), key_path))
fh = open(key_path)
key = AesKey.Read(fh.read())
fh.close()
return key
with open(key_path) as fh:
return AesKey.Read(fh.read())
def encrypt(key, msg):
return key.Encrypt(msg.encode('utf-8'))