Merge pull request #10754 from invenia/devel

Python 2/3 compatibility fixes to parsing in v2.
This commit is contained in:
Brian Coca 2015-04-20 11:00:35 -04:00
commit cef93db0a7
8 changed files with 193 additions and 129 deletions

View file

@ -20,7 +20,6 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from six import iteritems, string_types from six import iteritems, string_types
from types import NoneType
from ansible.errors import AnsibleParserError from ansible.errors import AnsibleParserError
from ansible.plugins import module_loader from ansible.plugins import module_loader
@ -165,7 +164,7 @@ class ModuleArgsParser:
# form is like: local_action: copy src=a dest=b ... pretty common # form is like: local_action: copy src=a dest=b ... pretty common
check_raw = action in ('command', 'shell', 'script') check_raw = action in ('command', 'shell', 'script')
args = parse_kv(thing, check_raw=check_raw) args = parse_kv(thing, check_raw=check_raw)
elif isinstance(thing, NoneType): elif thing is None:
# this can happen with modules which take no params, like ping: # this can happen with modules which take no params, like ping:
args = None args = None
else: else:

View file

@ -22,6 +22,7 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import sys
import os import os
import shlex import shlex
import shutil import shutil
@ -35,7 +36,10 @@ from hashlib import sha256
from hashlib import md5 from hashlib import md5
from binascii import hexlify from binascii import hexlify
from binascii import unhexlify from binascii import unhexlify
from six import binary_type, byte2int, PY2, text_type
from ansible import constants as C from ansible import constants as C
from ansible.utils.unicode import to_unicode, to_bytes
try: try:
from Crypto.Hash import SHA256, HMAC from Crypto.Hash import SHA256, HMAC
@ -60,15 +64,16 @@ except ImportError:
# AES IMPORTS # AES IMPORTS
try: try:
from Crypto.Cipher import AES as AES from Crypto.Cipher import AES as AES
HAS_AES = True HAS_AES = True
except ImportError: except ImportError:
HAS_AES = False HAS_AES = False
CRYPTO_UPGRADE = "ansible-vault requires a newer version of pycrypto than the one installed on your platform. You may fix this with OS-specific commands such as: yum install python-devel; rpm -e --nodeps python-crypto; pip install pycrypto" CRYPTO_UPGRADE = "ansible-vault requires a newer version of pycrypto than the one installed on your platform. You may fix this with OS-specific commands such as: yum install python-devel; rpm -e --nodeps python-crypto; pip install pycrypto"
HEADER='$ANSIBLE_VAULT' HEADER=u'$ANSIBLE_VAULT'
CIPHER_WHITELIST=['AES', 'AES256'] CIPHER_WHITELIST=['AES', 'AES256']
class VaultLib(object): class VaultLib(object):
def __init__(self, password): def __init__(self, password):
@ -76,26 +81,28 @@ class VaultLib(object):
self.cipher_name = None self.cipher_name = None
self.version = '1.1' self.version = '1.1'
def is_encrypted(self, data): def is_encrypted(self, data):
data = to_unicode(data)
if data.startswith(HEADER): if data.startswith(HEADER):
return True return True
else: else:
return False return False
def encrypt(self, data): def encrypt(self, data):
data = to_unicode(data)
if self.is_encrypted(data): if self.is_encrypted(data):
raise errors.AnsibleError("data is already encrypted") raise errors.AnsibleError("data is already encrypted")
if not self.cipher_name: if not self.cipher_name:
self.cipher_name = "AES256" self.cipher_name = "AES256"
#raise errors.AnsibleError("the cipher must be set before encrypting data") # raise errors.AnsibleError("the cipher must be set before encrypting data")
if 'Vault' + self.cipher_name in globals() and self.cipher_name in CIPHER_WHITELIST: if 'Vault' + self.cipher_name in globals() and self.cipher_name in CIPHER_WHITELIST:
cipher = globals()['Vault' + self.cipher_name] cipher = globals()['Vault' + self.cipher_name]
this_cipher = cipher() this_cipher = cipher()
else: else:
raise errors.AnsibleError("%s cipher could not be found" % self.cipher_name) raise errors.AnsibleError("{} cipher could not be found".format(self.cipher_name))
""" """
# combine sha + data # combine sha + data
@ -106,11 +113,13 @@ class VaultLib(object):
# encrypt sha + data # encrypt sha + data
enc_data = this_cipher.encrypt(data, self.password) enc_data = this_cipher.encrypt(data, self.password)
# add header # add header
tmp_data = self._add_header(enc_data) tmp_data = self._add_header(enc_data)
return tmp_data return tmp_data
def decrypt(self, data): def decrypt(self, data):
data = to_bytes(data)
if self.password is None: if self.password is None:
raise errors.AnsibleError("A vault password must be specified to decrypt data") raise errors.AnsibleError("A vault password must be specified to decrypt data")
@ -121,48 +130,47 @@ class VaultLib(object):
data = self._split_header(data) data = self._split_header(data)
# create the cipher object # create the cipher object
if 'Vault' + self.cipher_name in globals() and self.cipher_name in CIPHER_WHITELIST: ciphername = to_unicode(self.cipher_name)
cipher = globals()['Vault' + self.cipher_name] if 'Vault' + ciphername in globals() and ciphername in CIPHER_WHITELIST:
cipher = globals()['Vault' + ciphername]
this_cipher = cipher() this_cipher = cipher()
else: else:
raise errors.AnsibleError("%s cipher could not be found" % self.cipher_name) raise errors.AnsibleError("{} cipher could not be found".format(ciphername))
# try to unencrypt data # try to unencrypt data
data = this_cipher.decrypt(data, self.password) data = this_cipher.decrypt(data, self.password)
if data is None: if data is None:
raise errors.AnsibleError("Decryption failed") raise errors.AnsibleError("Decryption failed")
return data return data
def _add_header(self, data): def _add_header(self, data):
# combine header and encrypted data in 80 char columns # combine header and encrypted data in 80 char columns
#tmpdata = hexlify(data) #tmpdata = hexlify(data)
tmpdata = [data[i:i+80] for i in range(0, len(data), 80)] tmpdata = [to_bytes(data[i:i+80]) for i in range(0, len(data), 80)]
if not self.cipher_name: if not self.cipher_name:
raise errors.AnsibleError("the cipher must be set before adding a header") raise errors.AnsibleError("the cipher must be set before adding a header")
dirty_data = HEADER + ";" + str(self.version) + ";" + self.cipher_name + "\n" dirty_data = to_bytes(HEADER + ";" + self.version + ";" + self.cipher_name + "\n")
for l in tmpdata: for l in tmpdata:
dirty_data += l + '\n' dirty_data += l + b'\n'
return dirty_data return dirty_data
def _split_header(self, data): def _split_header(self, data):
# used by decrypt # used by decrypt
tmpdata = data.split('\n') tmpdata = data.split(b'\n')
tmpheader = tmpdata[0].strip().split(';') tmpheader = tmpdata[0].strip().split(b';')
self.version = str(tmpheader[1].strip()) self.version = to_unicode(tmpheader[1].strip())
self.cipher_name = str(tmpheader[2].strip()) self.cipher_name = to_unicode(tmpheader[2].strip())
clean_data = '\n'.join(tmpdata[1:]) clean_data = b'\n'.join(tmpdata[1:])
""" """
# strip out newline, join, unhex # strip out newline, join, unhex
clean_data = [ x.strip() for x in clean_data ] clean_data = [ x.strip() for x in clean_data ]
clean_data = unhexlify(''.join(clean_data)) clean_data = unhexlify(''.join(clean_data))
""" """
@ -176,9 +184,9 @@ class VaultLib(object):
pass pass
class VaultEditor(object): class VaultEditor(object):
# uses helper methods for write_file(self, filename, data) # uses helper methods for write_file(self, filename, data)
# to write a file so that code isn't duplicated for simple # to write a file so that code isn't duplicated for simple
# file I/O, ditto read_file(self, filename) and launch_editor(self, filename) # file I/O, ditto read_file(self, filename) and launch_editor(self, filename)
# ... "Don't Repeat Yourself", etc. # ... "Don't Repeat Yourself", etc.
def __init__(self, cipher_name, password, filename): def __init__(self, cipher_name, password, filename):
@ -302,7 +310,7 @@ class VaultEditor(object):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH: if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or not HAS_HASH:
raise errors.AnsibleError(CRYPTO_UPGRADE) raise errors.AnsibleError(CRYPTO_UPGRADE)
# decrypt # decrypt
tmpdata = self.read_data(self.filename) tmpdata = self.read_data(self.filename)
this_vault = VaultLib(self.password) this_vault = VaultLib(self.password)
dec_data = this_vault.decrypt(tmpdata) dec_data = this_vault.decrypt(tmpdata)
@ -324,10 +332,10 @@ class VaultEditor(object):
return tmpdata return tmpdata
def write_data(self, data, filename): def write_data(self, data, filename):
if os.path.isfile(filename): if os.path.isfile(filename):
os.remove(filename) os.remove(filename)
f = open(filename, "wb") f = open(filename, "wb")
f.write(data) f.write(to_bytes(data))
f.close() f.close()
def shuffle_files(self, src, dest): def shuffle_files(self, src, dest):
@ -369,9 +377,10 @@ class VaultAES(object):
""" Create a key and an initialization vector """ """ Create a key and an initialization vector """
d = d_i = '' d = d_i = b''
while len(d) < key_length + iv_length: while len(d) < key_length + iv_length:
d_i = md5(d_i + password + salt).digest() text = "{}{}{}".format(d_i, password, salt)
d_i = md5(to_bytes(text)).digest()
d += d_i d += d_i
key = d[:key_length] key = d[:key_length]
@ -385,28 +394,29 @@ class VaultAES(object):
# combine sha + data # combine sha + data
this_sha = sha256(data).hexdigest() this_sha = sha256(to_bytes(data)).hexdigest()
tmp_data = this_sha + "\n" + data tmp_data = this_sha + "\n" + data
in_file = BytesIO(tmp_data) in_file = BytesIO(to_bytes(tmp_data))
in_file.seek(0) in_file.seek(0)
out_file = BytesIO() out_file = BytesIO()
bs = AES.block_size bs = AES.block_size
# Get a block of random data. EL does not have Crypto.Random.new() # Get a block of random data. EL does not have Crypto.Random.new()
# so os.urandom is used for cross platform purposes # so os.urandom is used for cross platform purposes
salt = os.urandom(bs - len('Salted__')) salt = os.urandom(bs - len('Salted__'))
key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs) key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs)
cipher = AES.new(key, AES.MODE_CBC, iv) cipher = AES.new(key, AES.MODE_CBC, iv)
out_file.write('Salted__' + salt) full = to_bytes(b'Salted__' + salt)
out_file.write(full)
finished = False finished = False
while not finished: while not finished:
chunk = in_file.read(1024 * bs) chunk = in_file.read(1024 * bs)
if len(chunk) == 0 or len(chunk) % bs != 0: if len(chunk) == 0 or len(chunk) % bs != 0:
padding_length = (bs - len(chunk) % bs) or bs padding_length = (bs - len(chunk) % bs) or bs
chunk += padding_length * chr(padding_length) chunk += to_bytes(padding_length * chr(padding_length))
finished = True finished = True
out_file.write(cipher.encrypt(chunk)) out_file.write(cipher.encrypt(chunk))
@ -416,14 +426,14 @@ class VaultAES(object):
return tmp_data return tmp_data
def decrypt(self, data, password, key_length=32): def decrypt(self, data, password, key_length=32):
""" Read encrypted data from in_file and write decrypted to out_file """ """ Read encrypted data from in_file and write decrypted to out_file """
# http://stackoverflow.com/a/14989032 # http://stackoverflow.com/a/14989032
data = ''.join(data.split('\n')) data = b''.join(data.split(b'\n'))
data = unhexlify(data) data = unhexlify(data)
in_file = BytesIO(data) in_file = BytesIO(data)
@ -431,41 +441,49 @@ class VaultAES(object):
out_file = BytesIO() out_file = BytesIO()
bs = AES.block_size bs = AES.block_size
salt = in_file.read(bs)[len('Salted__'):] tmpsalt = in_file.read(bs)
salt = tmpsalt[len('Salted__'):]
key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs) key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs)
cipher = AES.new(key, AES.MODE_CBC, iv) cipher = AES.new(key, AES.MODE_CBC, iv)
next_chunk = '' next_chunk = b''
finished = False finished = False
while not finished: while not finished:
chunk, next_chunk = next_chunk, cipher.decrypt(in_file.read(1024 * bs)) chunk, next_chunk = next_chunk, cipher.decrypt(in_file.read(1024 * bs))
if len(next_chunk) == 0: if len(next_chunk) == 0:
padding_length = ord(chunk[-1]) if PY2:
padding_length = ord(chunk[-1])
else:
padding_length = chunk[-1]
chunk = chunk[:-padding_length] chunk = chunk[:-padding_length]
finished = True finished = True
out_file.write(chunk) out_file.write(chunk)
out_file.flush()
# reset the stream pointer to the beginning # reset the stream pointer to the beginning
out_file.seek(0) out_file.seek(0)
new_data = out_file.read() out_data = out_file.read()
out_file.close()
new_data = to_unicode(out_data)
# split out sha and verify decryption # split out sha and verify decryption
split_data = new_data.split("\n") split_data = new_data.split("\n")
this_sha = split_data[0] this_sha = split_data[0]
this_data = '\n'.join(split_data[1:]) this_data = '\n'.join(split_data[1:])
test_sha = sha256(this_data).hexdigest() test_sha = sha256(to_bytes(this_data)).hexdigest()
if this_sha != test_sha: if this_sha != test_sha:
raise errors.AnsibleError("Decryption failed") raise errors.AnsibleError("Decryption failed")
#return out_file.read()
return this_data return this_data
class VaultAES256(object): class VaultAES256(object):
""" """
Vault implementation using AES-CTR with an HMAC-SHA256 authentication code. Vault implementation using AES-CTR with an HMAC-SHA256 authentication code.
Keys are derived using PBKDF2 Keys are derived using PBKDF2
""" """
@ -481,7 +499,7 @@ class VaultAES256(object):
keylength = 32 keylength = 32
# match the size used for counter.new to avoid extra work # match the size used for counter.new to avoid extra work
ivlength = 16 ivlength = 16
hash_function = SHA256 hash_function = SHA256
@ -489,7 +507,7 @@ class VaultAES256(object):
pbkdf2_prf = lambda p, s: HMAC.new(p, s, hash_function).digest() pbkdf2_prf = lambda p, s: HMAC.new(p, s, hash_function).digest()
derivedkey = PBKDF2(password, salt, dkLen=(2 * keylength) + ivlength, derivedkey = PBKDF2(password, salt, dkLen=(2 * keylength) + ivlength,
count=10000, prf=pbkdf2_prf) count=10000, prf=pbkdf2_prf)
key1 = derivedkey[:keylength] key1 = derivedkey[:keylength]
@ -523,28 +541,28 @@ class VaultAES256(object):
cipher = AES.new(key1, AES.MODE_CTR, counter=ctr) cipher = AES.new(key1, AES.MODE_CTR, counter=ctr)
# ENCRYPT PADDED DATA # ENCRYPT PADDED DATA
cryptedData = cipher.encrypt(data) cryptedData = cipher.encrypt(data)
# COMBINE SALT, DIGEST AND DATA # COMBINE SALT, DIGEST AND DATA
hmac = HMAC.new(key2, cryptedData, SHA256) hmac = HMAC.new(key2, cryptedData, SHA256)
message = "%s\n%s\n%s" % ( hexlify(salt), hmac.hexdigest(), hexlify(cryptedData) ) message = b''.join([hexlify(salt), b"\n", to_bytes(hmac.hexdigest()), b"\n", hexlify(cryptedData)])
message = hexlify(message) message = hexlify(message)
return message return message
def decrypt(self, data, password): def decrypt(self, data, password):
# SPLIT SALT, DIGEST, AND DATA # SPLIT SALT, DIGEST, AND DATA
data = ''.join(data.split("\n")) data = b''.join(data.split(b"\n"))
data = unhexlify(data) data = unhexlify(data)
salt, cryptedHmac, cryptedData = data.split("\n", 2) salt, cryptedHmac, cryptedData = data.split(b"\n", 2)
salt = unhexlify(salt) salt = unhexlify(salt)
cryptedData = unhexlify(cryptedData) cryptedData = unhexlify(cryptedData)
key1, key2, iv = self.gen_key_initctr(password, salt) key1, key2, iv = self.gen_key_initctr(password, salt)
# EXIT EARLY IF DIGEST DOESN'T MATCH # EXIT EARLY IF DIGEST DOESN'T MATCH
hmacDecrypt = HMAC.new(key2, cryptedData, SHA256) hmacDecrypt = HMAC.new(key2, cryptedData, SHA256)
if not self.is_equal(cryptedHmac, hmacDecrypt.hexdigest()): if not self.is_equal(cryptedHmac, to_bytes(hmacDecrypt.hexdigest())):
return None return None
# SET THE COUNTER AND THE CIPHER # SET THE COUNTER AND THE CIPHER
@ -555,19 +573,31 @@ class VaultAES256(object):
decryptedData = cipher.decrypt(cryptedData) decryptedData = cipher.decrypt(cryptedData)
# UNPAD DATA # UNPAD DATA
padding_length = ord(decryptedData[-1]) try:
padding_length = ord(decryptedData[-1])
except TypeError:
padding_length = decryptedData[-1]
decryptedData = decryptedData[:-padding_length] decryptedData = decryptedData[:-padding_length]
return decryptedData return to_unicode(decryptedData)
def is_equal(self, a, b): def is_equal(self, a, b):
"""
Comparing 2 byte arrrays in constant time
to avoid timing attacks.
It would be nice if there was a library for this but
hey.
"""
# http://codahale.com/a-lesson-in-timing-attacks/ # http://codahale.com/a-lesson-in-timing-attacks/
if len(a) != len(b): if len(a) != len(b):
return False return False
result = 0 result = 0
for x, y in zip(a, b): for x, y in zip(a, b):
result |= ord(x) ^ ord(y) if PY2:
return result == 0 result |= ord(x) ^ ord(y)
else:
result |= x ^ y
return result == 0

View file

@ -19,14 +19,17 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
class AnsibleBaseYAMLObject: from six import text_type
class AnsibleBaseYAMLObject(object):
''' '''
the base class used to sub-class python built-in objects the base class used to sub-class python built-in objects
so that we can add attributes to them during yaml parsing so that we can add attributes to them during yaml parsing
''' '''
_data_source = None _data_source = None
_line_number = 0 _line_number = 0
_column_number = 0 _column_number = 0
def _get_ansible_position(self): def _get_ansible_position(self):
@ -36,21 +39,27 @@ class AnsibleBaseYAMLObject:
try: try:
(src, line, col) = obj (src, line, col) = obj
except (TypeError, ValueError): except (TypeError, ValueError):
raise AssertionError('ansible_pos can only be set with a tuple/list of three values: source, line number, column number') raise AssertionError(
self._data_source = src 'ansible_pos can only be set with a tuple/list '
self._line_number = line 'of three values: source, line number, column number'
)
self._data_source = src
self._line_number = line
self._column_number = col self._column_number = col
ansible_pos = property(_get_ansible_position, _set_ansible_position) ansible_pos = property(_get_ansible_position, _set_ansible_position)
class AnsibleMapping(AnsibleBaseYAMLObject, dict): class AnsibleMapping(AnsibleBaseYAMLObject, dict):
''' sub class for dictionaries ''' ''' sub class for dictionaries '''
pass pass
class AnsibleUnicode(AnsibleBaseYAMLObject, unicode):
class AnsibleUnicode(AnsibleBaseYAMLObject, text_type):
''' sub class for unicode objects ''' ''' sub class for unicode objects '''
pass pass
class AnsibleSequence(AnsibleBaseYAMLObject, list): class AnsibleSequence(AnsibleBaseYAMLObject, list):
''' sub class for lists ''' ''' sub class for lists '''
pass pass

View file

@ -19,6 +19,8 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from six import string_types, text_type, binary_type, PY3
# to_bytes and to_unicode were written by Toshio Kuratomi for the # to_bytes and to_unicode were written by Toshio Kuratomi for the
# python-kitchen library https://pypi.python.org/pypi/kitchen # python-kitchen library https://pypi.python.org/pypi/kitchen
# They are licensed in kitchen under the terms of the GPLv2+ # They are licensed in kitchen under the terms of the GPLv2+
@ -35,6 +37,9 @@ _LATIN1_ALIASES = frozenset(('latin-1', 'LATIN-1', 'latin1', 'LATIN1',
# EXCEPTION_CONVERTERS is defined below due to using to_unicode # EXCEPTION_CONVERTERS is defined below due to using to_unicode
if PY3:
basestring = (str, bytes)
def to_unicode(obj, encoding='utf-8', errors='replace', nonstring=None): def to_unicode(obj, encoding='utf-8', errors='replace', nonstring=None):
'''Convert an object into a :class:`unicode` string '''Convert an object into a :class:`unicode` string
@ -89,12 +94,12 @@ def to_unicode(obj, encoding='utf-8', errors='replace', nonstring=None):
# Could use isbasestring/isunicode here but we want this code to be as # Could use isbasestring/isunicode here but we want this code to be as
# fast as possible # fast as possible
if isinstance(obj, basestring): if isinstance(obj, basestring):
if isinstance(obj, unicode): if isinstance(obj, text_type):
return obj return obj
if encoding in _UTF8_ALIASES: if encoding in _UTF8_ALIASES:
return unicode(obj, 'utf-8', errors) return text_type(obj, 'utf-8', errors)
if encoding in _LATIN1_ALIASES: if encoding in _LATIN1_ALIASES:
return unicode(obj, 'latin-1', errors) return text_type(obj, 'latin-1', errors)
return obj.decode(encoding, errors) return obj.decode(encoding, errors)
if not nonstring: if not nonstring:
@ -110,19 +115,19 @@ def to_unicode(obj, encoding='utf-8', errors='replace', nonstring=None):
simple = None simple = None
if not simple: if not simple:
try: try:
simple = str(obj) simple = text_type(obj)
except UnicodeError: except UnicodeError:
try: try:
simple = obj.__str__() simple = obj.__str__()
except (UnicodeError, AttributeError): except (UnicodeError, AttributeError):
simple = u'' simple = u''
if isinstance(simple, str): if isinstance(simple, binary_type):
return unicode(simple, encoding, errors) return text_type(simple, encoding, errors)
return simple return simple
elif nonstring in ('repr', 'strict'): elif nonstring in ('repr', 'strict'):
obj_repr = repr(obj) obj_repr = repr(obj)
if isinstance(obj_repr, str): if isinstance(obj_repr, binary_type):
obj_repr = unicode(obj_repr, encoding, errors) obj_repr = text_type(obj_repr, encoding, errors)
if nonstring == 'repr': if nonstring == 'repr':
return obj_repr return obj_repr
raise TypeError('to_unicode was given "%(obj)s" which is neither' raise TypeError('to_unicode was given "%(obj)s" which is neither'
@ -198,19 +203,19 @@ def to_bytes(obj, encoding='utf-8', errors='replace', nonstring=None):
# Could use isbasestring, isbytestring here but we want this to be as fast # Could use isbasestring, isbytestring here but we want this to be as fast
# as possible # as possible
if isinstance(obj, basestring): if isinstance(obj, basestring):
if isinstance(obj, str): if isinstance(obj, binary_type):
return obj return obj
return obj.encode(encoding, errors) return obj.encode(encoding, errors)
if not nonstring: if not nonstring:
nonstring = 'simplerepr' nonstring = 'simplerepr'
if nonstring == 'empty': if nonstring == 'empty':
return '' return b''
elif nonstring == 'passthru': elif nonstring == 'passthru':
return obj return obj
elif nonstring == 'simplerepr': elif nonstring == 'simplerepr':
try: try:
simple = str(obj) simple = binary_type(obj)
except UnicodeError: except UnicodeError:
try: try:
simple = obj.__str__() simple = obj.__str__()
@ -220,19 +225,19 @@ def to_bytes(obj, encoding='utf-8', errors='replace', nonstring=None):
try: try:
simple = obj.__unicode__() simple = obj.__unicode__()
except (AttributeError, UnicodeError): except (AttributeError, UnicodeError):
simple = '' simple = b''
if isinstance(simple, unicode): if isinstance(simple, text_type):
simple = simple.encode(encoding, 'replace') simple = simple.encode(encoding, 'replace')
return simple return simple
elif nonstring in ('repr', 'strict'): elif nonstring in ('repr', 'strict'):
try: try:
obj_repr = obj.__repr__() obj_repr = obj.__repr__()
except (AttributeError, UnicodeError): except (AttributeError, UnicodeError):
obj_repr = '' obj_repr = b''
if isinstance(obj_repr, unicode): if isinstance(obj_repr, text_type):
obj_repr = obj_repr.encode(encoding, errors) obj_repr = obj_repr.encode(encoding, errors)
else: else:
obj_repr = str(obj_repr) obj_repr = binary_type(obj_repr)
if nonstring == 'repr': if nonstring == 'repr':
return obj_repr return obj_repr
raise TypeError('to_bytes was given "%(obj)s" which is neither' raise TypeError('to_bytes was given "%(obj)s" which is neither'

View file

@ -19,6 +19,7 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from six import PY2
from yaml.scanner import ScannerError from yaml.scanner import ScannerError
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
@ -79,6 +80,11 @@ class TestDataLoaderWithVault(unittest.TestCase):
3135306561356164310a343937653834643433343734653137383339323330626437313562306630 3135306561356164310a343937653834643433343734653137383339323330626437313562306630
3035 3035
""" """
with patch('__builtin__.open', mock_open(read_data=vaulted_data)): if PY2:
builtins_name = '__builtin__'
else:
builtins_name = 'builtins'
with patch(builtins_name + '.open', mock_open(read_data=vaulted_data)):
output = self._loader.load_from_file('dummy_vault.txt') output = self._loader.load_from_file('dummy_vault.txt')
self.assertEqual(output, dict(foo='bar')) self.assertEqual(output, dict(foo='bar'))

View file

@ -24,11 +24,14 @@ import os
import shutil import shutil
import time import time
import tempfile import tempfile
import six
from binascii import unhexlify from binascii import unhexlify
from binascii import hexlify from binascii import hexlify
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
from ansible.utils.unicode import to_bytes, to_unicode
from ansible import errors from ansible import errors
from ansible.parsing.vault import VaultLib from ansible.parsing.vault import VaultLib
@ -63,13 +66,13 @@ class TestVaultLib(unittest.TestCase):
'decrypt', 'decrypt',
'_add_header', '_add_header',
'_split_header',] '_split_header',]
for slot in slots: for slot in slots:
assert hasattr(v, slot), "VaultLib is missing the %s method" % slot assert hasattr(v, slot), "VaultLib is missing the %s method" % slot
def test_is_encrypted(self): def test_is_encrypted(self):
v = VaultLib(None) v = VaultLib(None)
assert not v.is_encrypted("foobar"), "encryption check on plaintext failed" assert not v.is_encrypted(u"foobar"), "encryption check on plaintext failed"
data = "$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify("ansible") data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible")
assert v.is_encrypted(data), "encryption check on headered text failed" assert v.is_encrypted(data), "encryption check on headered text failed"
def test_add_header(self): def test_add_header(self):
@ -77,22 +80,22 @@ class TestVaultLib(unittest.TestCase):
v.cipher_name = "TEST" v.cipher_name = "TEST"
sensitive_data = "ansible" sensitive_data = "ansible"
data = v._add_header(sensitive_data) data = v._add_header(sensitive_data)
lines = data.split('\n') lines = data.split(b'\n')
assert len(lines) > 1, "failed to properly add header" assert len(lines) > 1, "failed to properly add header"
header = lines[0] header = to_unicode(lines[0])
assert header.endswith(';TEST'), "header does end with cipher name" assert header.endswith(';TEST'), "header does end with cipher name"
header_parts = header.split(';') header_parts = header.split(';')
assert len(header_parts) == 3, "header has the wrong number of parts" assert len(header_parts) == 3, "header has the wrong number of parts"
assert header_parts[0] == '$ANSIBLE_VAULT', "header does not start with $ANSIBLE_VAULT" assert header_parts[0] == '$ANSIBLE_VAULT', "header does not start with $ANSIBLE_VAULT"
assert header_parts[1] == v.version, "header version is incorrect" assert header_parts[1] == v.version, "header version is incorrect"
assert header_parts[2] == 'TEST', "header does end with cipher name" assert header_parts[2] == 'TEST', "header does end with cipher name"
def test_split_header(self): def test_split_header(self):
v = VaultLib('ansible') v = VaultLib('ansible')
data = "$ANSIBLE_VAULT;9.9;TEST\nansible" data = b"$ANSIBLE_VAULT;9.9;TEST\nansible"
rdata = v._split_header(data) rdata = v._split_header(data)
lines = rdata.split('\n') lines = rdata.split(b'\n')
assert lines[0] == "ansible" assert lines[0] == b"ansible"
assert v.cipher_name == 'TEST', "cipher name was not set" assert v.cipher_name == 'TEST', "cipher name was not set"
assert v.version == "9.9" assert v.version == "9.9"
@ -100,11 +103,11 @@ class TestVaultLib(unittest.TestCase):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2:
raise SkipTest raise SkipTest
v = VaultLib('ansible') v = VaultLib('ansible')
v.cipher_name = 'AES' v.cipher_name = u'AES'
enc_data = v.encrypt("foobar") enc_data = v.encrypt("foobar")
dec_data = v.decrypt(enc_data) dec_data = v.decrypt(enc_data)
assert enc_data != "foobar", "encryption failed" assert enc_data != "foobar", "encryption failed"
assert dec_data == "foobar", "decryption failed" assert dec_data == "foobar", "decryption failed"
def test_encrypt_decrypt_aes256(self): def test_encrypt_decrypt_aes256(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2:
@ -114,20 +117,20 @@ class TestVaultLib(unittest.TestCase):
enc_data = v.encrypt("foobar") enc_data = v.encrypt("foobar")
dec_data = v.decrypt(enc_data) dec_data = v.decrypt(enc_data)
assert enc_data != "foobar", "encryption failed" assert enc_data != "foobar", "encryption failed"
assert dec_data == "foobar", "decryption failed" assert dec_data == "foobar", "decryption failed"
def test_encrypt_encrypted(self): def test_encrypt_encrypted(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2:
raise SkipTest raise SkipTest
v = VaultLib('ansible') v = VaultLib('ansible')
v.cipher_name = 'AES' v.cipher_name = 'AES'
data = "$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify("ansible") data = "$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(six.b("ansible"))
error_hit = False error_hit = False
try: try:
enc_data = v.encrypt(data) enc_data = v.encrypt(data)
except errors.AnsibleError as e: except errors.AnsibleError as e:
error_hit = True error_hit = True
assert error_hit, "No error was thrown when trying to encrypt data with a header" assert error_hit, "No error was thrown when trying to encrypt data with a header"
def test_decrypt_decrypted(self): def test_decrypt_decrypted(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2:
@ -139,7 +142,7 @@ class TestVaultLib(unittest.TestCase):
dec_data = v.decrypt(data) dec_data = v.decrypt(data)
except errors.AnsibleError as e: except errors.AnsibleError as e:
error_hit = True error_hit = True
assert error_hit, "No error was thrown when trying to decrypt data without a header" assert error_hit, "No error was thrown when trying to decrypt data without a header"
def test_cipher_not_set(self): def test_cipher_not_set(self):
# not setting the cipher should default to AES256 # not setting the cipher should default to AES256
@ -152,5 +155,5 @@ class TestVaultLib(unittest.TestCase):
enc_data = v.encrypt(data) enc_data = v.encrypt(data)
except errors.AnsibleError as e: except errors.AnsibleError as e:
error_hit = True error_hit = True
assert not error_hit, "An error was thrown when trying to encrypt data without the cipher set" assert not error_hit, "An error was thrown when trying to encrypt data without the cipher set"
assert v.cipher_name == "AES256", "cipher name is not set to AES256: %s" % v.cipher_name assert v.cipher_name == "AES256", "cipher name is not set to AES256: %s" % v.cipher_name

View file

@ -21,6 +21,7 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
#!/usr/bin/env python #!/usr/bin/env python
import sys
import getpass import getpass
import os import os
import shutil import shutil
@ -32,6 +33,7 @@ from nose.plugins.skip import SkipTest
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch from ansible.compat.tests.mock import patch
from ansible.utils.unicode import to_bytes, to_unicode
from ansible import errors from ansible import errors
from ansible.parsing.vault import VaultLib from ansible.parsing.vault import VaultLib
@ -88,12 +90,12 @@ class TestVaultEditor(unittest.TestCase):
'read_data', 'read_data',
'write_data', 'write_data',
'shuffle_files'] 'shuffle_files']
for slot in slots: for slot in slots:
assert hasattr(v, slot), "VaultLib is missing the %s method" % slot assert hasattr(v, slot), "VaultLib is missing the %s method" % slot
@patch.object(VaultEditor, '_editor_shell_command') @patch.object(VaultEditor, '_editor_shell_command')
def test_create_file(self, mock_editor_shell_command): def test_create_file(self, mock_editor_shell_command):
def sc_side_effect(filename): def sc_side_effect(filename):
return ['touch', filename] return ['touch', filename]
mock_editor_shell_command.side_effect = sc_side_effect mock_editor_shell_command.side_effect = sc_side_effect
@ -107,12 +109,16 @@ class TestVaultEditor(unittest.TestCase):
self.assertTrue(os.path.exists(tmp_file.name)) self.assertTrue(os.path.exists(tmp_file.name))
def test_decrypt_1_0(self): def test_decrypt_1_0(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: """
Skip testing decrypting 1.0 files if we don't have access to AES, KDF or
Counter, or we are running on python3 since VaultAES hasn't been backported.
"""
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or sys.version > '3':
raise SkipTest raise SkipTest
v10_file = tempfile.NamedTemporaryFile(delete=False) v10_file = tempfile.NamedTemporaryFile(delete=False)
with v10_file as f: with v10_file as f:
f.write(v10_data) f.write(to_bytes(v10_data))
ve = VaultEditor(None, "ansible", v10_file.name) ve = VaultEditor(None, "ansible", v10_file.name)
@ -125,13 +131,13 @@ class TestVaultEditor(unittest.TestCase):
# verify decrypted content # verify decrypted content
f = open(v10_file.name, "rb") f = open(v10_file.name, "rb")
fdata = f.read() fdata = to_unicode(f.read())
f.close() f.close()
os.unlink(v10_file.name) os.unlink(v10_file.name)
assert error_hit == False, "error decrypting 1.0 file" assert error_hit == False, "error decrypting 1.0 file"
assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip() assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip()
def test_decrypt_1_1(self): def test_decrypt_1_1(self):
@ -140,7 +146,7 @@ class TestVaultEditor(unittest.TestCase):
v11_file = tempfile.NamedTemporaryFile(delete=False) v11_file = tempfile.NamedTemporaryFile(delete=False)
with v11_file as f: with v11_file as f:
f.write(v11_data) f.write(to_bytes(v11_data))
ve = VaultEditor(None, "ansible", v11_file.name) ve = VaultEditor(None, "ansible", v11_file.name)
@ -153,28 +159,32 @@ class TestVaultEditor(unittest.TestCase):
# verify decrypted content # verify decrypted content
f = open(v11_file.name, "rb") f = open(v11_file.name, "rb")
fdata = f.read() fdata = to_unicode(f.read())
f.close() f.close()
os.unlink(v11_file.name) os.unlink(v11_file.name)
assert error_hit == False, "error decrypting 1.0 file" assert error_hit == False, "error decrypting 1.0 file"
assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip() assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip()
def test_rekey_migration(self): def test_rekey_migration(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: """
Skip testing rekeying files if we don't have access to AES, KDF or
Counter, or we are running on python3 since VaultAES hasn't been backported.
"""
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or sys.version > '3':
raise SkipTest raise SkipTest
v10_file = tempfile.NamedTemporaryFile(delete=False) v10_file = tempfile.NamedTemporaryFile(delete=False)
with v10_file as f: with v10_file as f:
f.write(v10_data) f.write(to_bytes(v10_data))
ve = VaultEditor(None, "ansible", v10_file.name) ve = VaultEditor(None, "ansible", v10_file.name)
# make sure the password functions for the cipher # make sure the password functions for the cipher
error_hit = False error_hit = False
try: try:
ve.rekey_file('ansible2') ve.rekey_file('ansible2')
except errors.AnsibleError as e: except errors.AnsibleError as e:
error_hit = True error_hit = True
@ -184,7 +194,7 @@ class TestVaultEditor(unittest.TestCase):
fdata = f.read() fdata = f.read()
f.close() f.close()
assert error_hit == False, "error rekeying 1.0 file to 1.1" assert error_hit == False, "error rekeying 1.0 file to 1.1"
# ensure filedata can be decrypted, is 1.1 and is AES256 # ensure filedata can be decrypted, is 1.1 and is AES256
vl = VaultLib("ansible2") vl = VaultLib("ansible2")
@ -198,7 +208,7 @@ class TestVaultEditor(unittest.TestCase):
os.unlink(v10_file.name) os.unlink(v10_file.name)
assert vl.cipher_name == "AES256", "wrong cipher name set after rekey: %s" % vl.cipher_name assert vl.cipher_name == "AES256", "wrong cipher name set after rekey: %s" % vl.cipher_name
assert error_hit == False, "error decrypting migrated 1.0 file" assert error_hit == False, "error decrypting migrated 1.0 file"
assert dec_data.strip() == "foo", "incorrect decryption of rekeyed/migrated file: %s" % dec_data assert dec_data.strip() == "foo", "incorrect decryption of rekeyed/migrated file: %s" % dec_data

View file

@ -20,6 +20,7 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from six import text_type, binary_type
from six.moves import StringIO from six.moves import StringIO
from collections import Sequence, Set, Mapping from collections import Sequence, Set, Mapping
@ -28,6 +29,7 @@ from ansible.compat.tests.mock import patch
from ansible.parsing.yaml.loader import AnsibleLoader from ansible.parsing.yaml.loader import AnsibleLoader
class TestAnsibleLoaderBasic(unittest.TestCase): class TestAnsibleLoaderBasic(unittest.TestCase):
def setUp(self): def setUp(self):
@ -52,7 +54,7 @@ class TestAnsibleLoaderBasic(unittest.TestCase):
loader = AnsibleLoader(stream, 'myfile.yml') loader = AnsibleLoader(stream, 'myfile.yml')
data = loader.get_single_data() data = loader.get_single_data()
self.assertEqual(data, u'Ansible') self.assertEqual(data, u'Ansible')
self.assertIsInstance(data, unicode) self.assertIsInstance(data, text_type)
self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17)) self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17))
@ -63,7 +65,7 @@ class TestAnsibleLoaderBasic(unittest.TestCase):
loader = AnsibleLoader(stream, 'myfile.yml') loader = AnsibleLoader(stream, 'myfile.yml')
data = loader.get_single_data() data = loader.get_single_data()
self.assertEqual(data, u'Cafè Eñyei') self.assertEqual(data, u'Cafè Eñyei')
self.assertIsInstance(data, unicode) self.assertIsInstance(data, text_type)
self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17)) self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17))
@ -76,8 +78,8 @@ class TestAnsibleLoaderBasic(unittest.TestCase):
data = loader.get_single_data() data = loader.get_single_data()
self.assertEqual(data, {'webster': 'daniel', 'oed': 'oxford'}) self.assertEqual(data, {'webster': 'daniel', 'oed': 'oxford'})
self.assertEqual(len(data), 2) self.assertEqual(len(data), 2)
self.assertIsInstance(data.keys()[0], unicode) self.assertIsInstance(list(data.keys())[0], text_type)
self.assertIsInstance(data.values()[0], unicode) self.assertIsInstance(list(data.values())[0], text_type)
# Beginning of the first key # Beginning of the first key
self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17)) self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17))
@ -94,7 +96,7 @@ class TestAnsibleLoaderBasic(unittest.TestCase):
data = loader.get_single_data() data = loader.get_single_data()
self.assertEqual(data, [u'a', u'b']) self.assertEqual(data, [u'a', u'b'])
self.assertEqual(len(data), 2) self.assertEqual(len(data), 2)
self.assertIsInstance(data[0], unicode) self.assertIsInstance(data[0], text_type)
self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17)) self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17))
@ -204,10 +206,10 @@ class TestAnsibleLoaderPlay(unittest.TestCase):
def walk(self, data): def walk(self, data):
# Make sure there's no str in the data # Make sure there's no str in the data
self.assertNotIsInstance(data, str) self.assertNotIsInstance(data, binary_type)
# Descend into various container types # Descend into various container types
if isinstance(data, unicode): if isinstance(data, text_type):
# strings are a sequence so we have to be explicit here # strings are a sequence so we have to be explicit here
return return
elif isinstance(data, (Sequence, Set)): elif isinstance(data, (Sequence, Set)):