From e079758b316996093df1f5f5812b1ff639b78412 Mon Sep 17 00:00:00 2001 From: Felix Fontein Date: Mon, 8 Apr 2019 13:59:55 +0200 Subject: [PATCH] Move refactoring steps from #54635 to own PR. (#54690) --- .../54690-openssl_certificate-assertonly.yml | 2 + .../modules/crypto/openssl_certificate.py | 935 ++++++++++-------- 2 files changed, 507 insertions(+), 430 deletions(-) create mode 100644 changelogs/fragments/54690-openssl_certificate-assertonly.yml diff --git a/changelogs/fragments/54690-openssl_certificate-assertonly.yml b/changelogs/fragments/54690-openssl_certificate-assertonly.yml new file mode 100644 index 00000000000..33c8db2ead1 --- /dev/null +++ b/changelogs/fragments/54690-openssl_certificate-assertonly.yml @@ -0,0 +1,2 @@ +minor_changes: +- "openssl_certificate - the messages of the ``assertonly`` provider with respect to private key and CSR checking are now more precise." diff --git a/lib/ansible/modules/crypto/openssl_certificate.py b/lib/ansible/modules/crypto/openssl_certificate.py index 242fd098291..ee0b226d414 100644 --- a/lib/ansible/modules/crypto/openssl_certificate.py +++ b/lib/ansible/modules/crypto/openssl_certificate.py @@ -533,6 +533,7 @@ backup_file: from random import randint +import abc import datetime import os import traceback @@ -1103,10 +1104,28 @@ class OwnCACertificate(Certificate): return result -class AssertOnlyCertificateCryptography(Certificate): - """Validate the supplied cert, using the cryptography backend""" - def __init__(self, module): - super(AssertOnlyCertificateCryptography, self).__init__(module, 'cryptography') +def compare_sets(subset, superset, equality=False): + if equality: + return set(subset) == set(superset) + else: + return all(x in superset for x in subset) + + +def compare_dicts(subset, superset, equality=False): + if equality: + return subset == superset + else: + return all(superset.get(x) == v for x, v in subset.items()) + + +NO_EXTENSION = 'no extension' + + +class AssertOnlyCertificateBase(Certificate): + + def __init__(self, module, backend): + super(AssertOnlyCertificateBase, self).__init__(module, backend) + self.signature_algorithms = module.params['signature_algorithms'] if module.params['subject']: self.subject = crypto_utils.parse_name_field(module.params['subject']) @@ -1120,226 +1139,256 @@ class AssertOnlyCertificateCryptography(Certificate): self.issuer_strict = module.params['issuer_strict'] self.has_expired = module.params['has_expired'] self.version = module.params['version'] - self.keyUsage = module.params['key_usage'] - self.keyUsage_strict = module.params['key_usage_strict'] - self.extendedKeyUsage = module.params['extended_key_usage'] - self.extendedKeyUsage_strict = module.params['extended_key_usage_strict'] - self.subjectAltName = module.params['subject_alt_name'] - self.subjectAltName_strict = module.params['subject_alt_name_strict'] - self.notBefore = module.params['not_before'], - self.notAfter = module.params['not_after'], - self.valid_at = module.params['valid_at'], - self.invalid_at = module.params['invalid_at'], - self.valid_in = module.params['valid_in'], - self.message = [] + self.key_usage = module.params['key_usage'] + self.key_usage_strict = module.params['key_usage_strict'] + self.extended_key_usage = module.params['extended_key_usage'] + self.extended_key_usage_strict = module.params['extended_key_usage_strict'] + self.subject_alt_name = module.params['subject_alt_name'] + self.subject_alt_name_strict = module.params['subject_alt_name_strict'] + self.not_before = module.params['not_before'] + self.not_after = module.params['not_after'] + self.valid_at = module.params['valid_at'] + self.invalid_at = module.params['invalid_at'] + self.valid_in = module.params['valid_in'] + if self.valid_in and not self.valid_in.startswith("+") and not self.valid_in.startswith("-"): + try: + int(self.valid_in) + except ValueError: + module.fail_json(msg='The supplied value for "valid_in" (%s) is not an integer or a valid timespec' % self.valid_in) + self.valid_in = "+" + self.valid_in + "s" - def assertonly(self): + # Load objects self.cert = crypto_utils.load_certificate(self.path, backend=self.backend) + if self.privatekey_path is not None: + try: + self.privatekey = crypto_utils.load_privatekey( + self.privatekey_path, + self.privatekey_passphrase, + backend=self.backend + ) + except crypto_utils.OpenSSLBadPassphraseError as exc: + raise CertificateError(exc) + if self.csr_path is not None: + self.csr = crypto_utils.load_certificate_request(self.csr_path, backend=self.backend) - def _validate_signature_algorithms(): - if self.signature_algorithms: - if self.cert.signature_algorithm_oid._name not in self.signature_algorithms: - self.message.append( - 'Invalid signature algorithm (got %s, expected one of %s)' % - (self.cert.signature_algorithm_oid._name, self.signature_algorithms) - ) + @abc.abstractmethod + def _validate_privatekey(self): + pass - def _validate_subject(): - if self.subject: - expected_subject = Name([NameAttribute(oid=crypto_utils.cryptography_get_name_oid(sub[0]), value=to_text(sub[1])) - for sub in self.subject]) - cert_subject = self.cert.subject - if (not self.subject_strict and not all(x in cert_subject for x in expected_subject)) or \ - (self.subject_strict and not set(expected_subject) == set(cert_subject)): - self.message.append( - 'Invalid subject component (got %s, expected all of %s to be present)' % - (cert_subject, expected_subject) - ) + @abc.abstractmethod + def _validate_csr_signature(self): + pass - def _validate_issuer(): - if self.issuer: - expected_issuer = Name([NameAttribute(oid=crypto_utils.cryptography_get_name_oid(iss[0]), value=to_text(iss[1])) - for iss in self.issuer]) - cert_issuer = self.cert.issuer - if (not self.issuer_strict and not all(x in cert_issuer for x in expected_issuer)) or \ - (self.issuer_strict and not set(expected_issuer) == set(cert_issuer)): - self.message.append( - 'Invalid issuer component (got %s, expected all of %s to be present)' % (cert_issuer, self.issuer) - ) + @abc.abstractmethod + def _validate_csr_subject(self): + pass - def _validate_has_expired(): - cert_not_after = self.cert.not_valid_after - cert_expired = cert_not_after < datetime.datetime.utcnow() + @abc.abstractmethod + def _validate_csr_extensions(self): + pass - if self.has_expired != cert_expired: - self.message.append( - 'Certificate expiration check failed (certificate expiration is %s, expected %s)' % (cert_expired, self.has_expired) + @abc.abstractmethod + def _validate_signature_algorithms(self): + pass + + @abc.abstractmethod + def _validate_subject(self): + pass + + @abc.abstractmethod + def _validate_issuer(self): + pass + + @abc.abstractmethod + def _validate_has_expired(self): + pass + + @abc.abstractmethod + def _validate_version(self): + pass + + @abc.abstractmethod + def _validate_key_usage(self): + pass + + @abc.abstractmethod + def _validate_extended_key_usage(self): + pass + + @abc.abstractmethod + def _validate_subject_alt_name(self): + pass + + @abc.abstractmethod + def _validate_not_before(self): + pass + + @abc.abstractmethod + def _validate_not_after(self): + pass + + @abc.abstractmethod + def _validate_valid_at(self): + pass + + @abc.abstractmethod + def _validate_invalid_at(self): + pass + + @abc.abstractmethod + def _validate_valid_in(self): + pass + + def assertonly(self, module): + messages = [] + if self.privatekey_path is not None: + if not self._validate_privatekey(): + messages.append( + 'Certificate %s and private key %s do not match' % + (self.path, self.privatekey_path) ) - def _validate_version(): - # FIXME - if self.version: - expected_version = x509.Version(int(self.version) - 1) - if expected_version != self.cert.version: - self.message.append( - 'Invalid certificate version number (got %s, expected %s)' % (self.cert.version, self.version) - ) + if self.csr_path is not None: + if not self._validate_csr_signature(): + messages.append( + 'Certificate %s and CSR %s do not match: private key mismatch' % + (self.path, self.csr_path) + ) + if not self._validate_csr_subject(): + messages.append( + 'Certificate %s and CSR %s do not match: subject mismatch' % + (self.path, self.csr_path) + ) + if not self._validate_csr_extensions(): + messages.append( + 'Certificate %s and CSR %s do not match: extensions mismatch' % + (self.path, self.csr_path) + ) - def _validate_keyUsage(): - if self.keyUsage: - try: - current_keyusage = self.cert.extensions.get_extension_for_class(x509.KeyUsage).value - expected_keyusage = x509.KeyUsage(**crypto_utils.cryptography_parse_key_usage_params(self.keyUsage)) - test_keyusage = dict( - digital_signature=current_keyusage.digital_signature, - content_commitment=current_keyusage.content_commitment, - key_encipherment=current_keyusage.key_encipherment, - data_encipherment=current_keyusage.data_encipherment, - key_agreement=current_keyusage.key_agreement, - key_cert_sign=current_keyusage.key_cert_sign, - crl_sign=current_keyusage.crl_sign, - ) - if test_keyusage['key_agreement']: - test_keyusage.update(dict( - encipher_only=current_keyusage.encipher_only, - decipher_only=current_keyusage.decipher_only - )) - else: - test_keyusage.update(dict( - encipher_only=False, - decipher_only=False - )) + if self.signature_algorithms is not None: + wrong_alg = self._validate_signature_algorithms() + if wrong_alg: + messages.append( + 'Invalid signature algorithm (got %s, expected one of %s)' % + (wrong_alg, self.signature_algorithms) + ) - key_usages = crypto_utils.cryptography_parse_key_usage_params(self.keyUsage) - if (not self.keyUsage_strict and not all(key_usages[x] == test_keyusage[x] for x in key_usages)) or \ - (self.keyUsage_strict and current_keyusage != expected_keyusage): - self.message.append( - 'Invalid keyUsage components (got %s, expected all of %s to be present)' % - ([x for x in test_keyusage if x is True], [x for x in self.keyUsage if x is True]) - ) + if self.subject is not None: + failure = self._validate_subject() + if failure: + dummy, cert_subject = failure + messages.append( + 'Invalid subject component (got %s, expected all of %s to be present)' % + (cert_subject, self.subject) + ) - except cryptography.x509.ExtensionNotFound: - self.message.append('Found no keyUsage extension') + if self.issuer is not None: + failure = self._validate_issuer() + if failure: + dummy, cert_issuer = failure + messages.append( + 'Invalid issuer component (got %s, expected all of %s to be present)' % (cert_issuer, self.issuer) + ) - def _validate_extendedKeyUsage(): - if self.extendedKeyUsage: - try: - current_ext_keyusage = self.cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value - usages = [crypto_utils.cryptography_get_ext_keyusage(usage) for usage in self.extendedKeyUsage] - expected_ext_keyusage = x509.ExtendedKeyUsage(usages) - if (not self.extendedKeyUsage_strict and not all(x in current_ext_keyusage for x in expected_ext_keyusage)) or \ - (self.extendedKeyUsage_strict and not current_ext_keyusage == expected_ext_keyusage): - self.message.append( - 'Invalid extendedKeyUsage component (got %s, expected all of %s to be present)' % ([xku.value for xku in current_ext_keyusage], - [exku.value for exku in expected_ext_keyusage]) - ) + if self.has_expired is not None: + cert_expired = self._validate_has_expired() + if cert_expired != self.has_expired: + messages.append( + 'Certificate expiration check failed (certificate expiration is %s, expected %s)' % + (cert_expired, self.has_expired) + ) - except cryptography.x509.ExtensionNotFound: - self.message.append('Found no extendedKeyUsage extension') + if self.version is not None: + cert_version = self._validate_version() + if cert_version != self.version: + messages.append( + 'Invalid certificate version number (got %s, expected %s)' % + (cert_version, self.version) + ) - def _validate_subjectAltName(): - if self.subjectAltName: - try: - current_san = self.cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value - expected_san = [crypto_utils.cryptography_get_name(san) for san in self.subjectAltName] - if (not self.subjectAltName_strict and not all(x in current_san for x in expected_san)) or \ - (self.subjectAltName_strict and not set(current_san) == set(expected_san)): - self.message.append( - 'Invalid subjectAltName component (got %s, expected all of %s to be present)' % - (current_san, self.subjectAltName) - ) - except cryptography.x509.ExtensionNotFound: - self.message.append('Found no subjectAltName extension') + if self.key_usage is not None: + failure = self._validate_key_usage() + if failure == NO_EXTENSION: + messages.append('Found no keyUsage extension') + elif failure: + dummy, cert_key_usage = failure + messages.append( + 'Invalid keyUsage components (got %s, expected all of %s to be present)' % + (cert_key_usage, self.key_usage) + ) - def _validate_notBefore(): - if self.notBefore[0]: - # try: - if self.cert.not_valid_before != self.get_relative_time_option(self.notBefore[0], 'not_before'): - self.message.append( - 'Invalid notBefore component (got %s, expected %s to be present)' % (self.cert.not_valid_before, self.notBefore) - ) - # except AttributeError: - # self.message.append(str(self.notBefore)) + if self.extended_key_usage is not None: + failure = self._validate_extended_key_usage() + if failure == NO_EXTENSION: + messages.append('Found no extendedKeyUsage extension') + elif failure: + dummy, ext_cert_key_usage = failure + messages.append( + 'Invalid extendedKeyUsage component (got %s, expected all of %s to be present)' % (ext_cert_key_usage, self.extended_key_usage) + ) - def _validate_notAfter(): - if self.notAfter[0]: - if self.cert.not_valid_after != self.get_relative_time_option(self.notAfter[0], 'not_after'): - self.message.append( - 'Invalid notAfter component (got %s, expected %s to be present)' % (self.cert.not_valid_after, self.notAfter) - ) + if self.subject_alt_name is not None: + failure = self._validate_subject_alt_name() + if failure == NO_EXTENSION: + messages.append('Found no subjectAltName extension') + elif failure: + dummy, cert_san = failure + messages.append( + 'Invalid subjectAltName component (got %s, expected all of %s to be present)' % + (cert_san, self.subject_alt_name) + ) - def _validate_valid_at(): - if self.valid_at[0]: - rt = self.get_relative_time_option(self.valid_at[0], 'valid_at') - if not (self.cert.not_valid_before <= rt <= self.cert.not_valid_after): - self.message.append( - 'Certificate is not valid for the specified date (%s) - notBefore: %s - notAfter: %s' % (self.valid_at, - self.cert.not_valid_before, - self.cert.not_valid_after) - ) + if self.not_before is not None: + cert_not_valid_before = self._validate_not_before() + if cert_not_valid_before != self.get_relative_time_option(self.not_before, 'not_before'): + messages.append( + 'Invalid not_before component (got %s, expected %s to be present)' % + (cert_not_valid_before, self.not_before) + ) - def _validate_invalid_at(): - if self.invalid_at[0]: - if (self.get_relative_time_option(self.invalid_at[0], 'invalid_at') <= self.cert.not_valid_before) \ - or (self.get_relative_time_option(self.invalid_at, 'invalid_at') >= self.cert.not_valid_after): - self.message.append( - 'Certificate is not invalid for the specified date (%s) - notBefore: %s - notAfter: %s' % (self.invalid_at, - self.cert.not_valid_before, - self.cert.not_valid_after) - ) + if self.not_after is not None: + cert_not_valid_after = self._validate_not_after() + if cert_not_valid_after != self.get_relative_time_option(self.not_after, 'not_after'): + messages.append( + 'Invalid not_after component (got %s, expected %s to be present)' % + (cert_not_valid_after, self.not_after) + ) - def _validate_valid_in(): - if self.valid_in[0]: - if not self.valid_in[0].startswith("+") and not self.valid_in[0].startswith("-"): - try: - int(self.valid_in[0]) - except ValueError: - raise CertificateError( - 'The supplied value for "valid_in" (%s) is not an integer or a valid timespec' % self.valid_in) - self.valid_in = "+" + self.valid_in + "s" - valid_in_date = self.get_relative_time_option(self.valid_in[0], "valid_in") - if not self.cert.not_valid_before <= valid_in_date <= self.cert.not_valid_after: - self.message.append( - 'Certificate is not valid in %s from now (that would be %s) - notBefore: %s - notAfter: %s' - % (self.valid_in, valid_in_date, - self.cert.not_valid_before, - self.cert.not_valid_after)) + if self.valid_at is not None: + not_before, valid_at, not_after = self._validate_valid_at() + if not (not_before <= valid_at <= not_after): + messages.append( + 'Certificate is not valid for the specified date (%s) - not_before: %s - not_after: %s' % + (self.valid_at, not_before, not_after) + ) - for validation in ['signature_algorithms', 'subject', 'issuer', - 'has_expired', 'version', 'keyUsage', - 'extendedKeyUsage', 'subjectAltName', - 'notBefore', 'notAfter', 'valid_at', 'valid_in', 'invalid_at']: - f_name = locals()['_validate_%s' % validation] - f_name() + if self.invalid_at is not None: + not_before, invalid_at, not_after = self._validate_invalid_at() + if (invalid_at <= not_before) or (invalid_at >= not_after): + messages.append( + 'Certificate is not invalid for the specified date (%s) - not_before: %s - not_after: %s' % + (self.invalid_at, not_before, not_after) + ) + + if self.valid_in is not None: + not_before, valid_in, not_after = self._validate_valid_in() + if not not_before <= valid_in <= not_after: + messages.append( + 'Certificate is not valid in %s from now (that would be %s) - not_before: %s - not_after: %s' % + (self.valid_in, valid_in, not_before, not_after) + ) + return messages def generate(self, module): """Don't generate anything - only assert""" - - self.assertonly() - - try: - if self.privatekey_path and \ - not super(AssertOnlyCertificateCryptography, self).check(module, perms_required=False): - self.message.append( - 'Certificate %s and private key %s do not match' % (self.path, self.privatekey_path) - ) - except CertificateError as e: - self.message.append( - 'Error while reading private key %s: %s' % (self.privatekey_path, str(e)) - ) - - if len(self.message): - module.fail_json(msg=' | '.join(self.message)) + messages = self.assertonly(module) + if messages: + module.fail_json(msg=' | '.join(messages)) def check(self, module, perms_required=False): """Ensure the resource is in its desired state.""" - - parent_check = super(AssertOnlyCertificateCryptography, self).check(module, perms_required) - self.assertonly() - assertonly_check = not len(self.message) - self.message = [] - - return parent_check and assertonly_check + messages = self.assertonly(module) + return len(messages) == 0 def dump(self, check_mode=False): result = { @@ -1351,45 +1400,150 @@ class AssertOnlyCertificateCryptography(Certificate): return result -class AssertOnlyCertificate(Certificate): +class AssertOnlyCertificateCryptography(AssertOnlyCertificateBase): + """Validate the supplied cert, using the cryptography backend""" + def __init__(self, module): + super(AssertOnlyCertificateCryptography, self).__init__(module, 'cryptography') + + def _validate_privatekey(self): + return self.cert.public_key().public_numbers() == self.privatekey.public_key().public_numbers() + + def _validate_csr_signature(self): + if not self.csr.is_signature_valid: + return False + if self.csr.public_key().public_numbers() != self.cert.public_key().public_numbers(): + return False + + def _validate_csr_subject(self): + if self.csr.subject != self.cert.subject: + return False + + def _validate_csr_extensions(self): + cert_exts = self.cert.extensions + csr_exts = self.csr.extensions + if len(cert_exts) != len(csr_exts): + return False + for cert_ext in cert_exts: + try: + csr_ext = csr_exts.get_extension_for_oid(cert_ext.oid) + if cert_ext != csr_ext: + return False + except cryptography.x509.ExtensionNotFound as dummy: + return False + return True + + def _validate_signature_algorithms(self): + if self.cert.signature_algorithm_oid._name not in self.signature_algorithms: + return self.cert.signature_algorithm_oid._name + + def _validate_subject(self): + expected_subject = Name([NameAttribute(oid=crypto_utils.cryptography_get_name_oid(sub[0]), value=to_text(sub[1])) + for sub in self.subject]) + cert_subject = self.cert.subject + if not compare_sets(expected_subject, cert_subject, self.subject_strict): + return expected_subject, cert_subject + + def _validate_issuer(self): + expected_issuer = Name([NameAttribute(oid=crypto_utils.cryptography_get_name_oid(iss[0]), value=to_text(iss[1])) + for iss in self.issuer]) + cert_issuer = self.cert.issuer + if not compare_sets(expected_issuer, cert_issuer, self.issuer_strict): + return self.issuer, cert_issuer + + def _validate_has_expired(self): + cert_not_after = self.cert.not_valid_after + cert_expired = cert_not_after < datetime.datetime.utcnow() + return cert_expired + + def _validate_version(self): + if self.cert.version == x509.Version.v1: + return 1 + if self.cert.version == x509.Version.v3: + return 3 + return "unknown" + + def _validate_key_usage(self): + try: + current_key_usage = self.cert.extensions.get_extension_for_class(x509.KeyUsage).value + test_key_usage = dict( + digital_signature=current_key_usage.digital_signature, + content_commitment=current_key_usage.content_commitment, + key_encipherment=current_key_usage.key_encipherment, + data_encipherment=current_key_usage.data_encipherment, + key_agreement=current_key_usage.key_agreement, + key_cert_sign=current_key_usage.key_cert_sign, + crl_sign=current_key_usage.crl_sign, + encipher_only=False, + decipher_only=False + ) + if test_key_usage['key_agreement']: + test_key_usage.update(dict( + encipher_only=current_key_usage.encipher_only, + decipher_only=current_key_usage.decipher_only + )) + + key_usages = crypto_utils.cryptography_parse_key_usage_params(self.key_usage) + if not compare_dicts(key_usages, test_key_usage, self.key_usage_strict): + return self.key_usage, [x for x in test_key_usage if x is True] + + except cryptography.x509.ExtensionNotFound: + # This is only bad if the user specified a non-empty list + if self.key_usage: + return NO_EXTENSION + + def _validate_extended_key_usage(self): + try: + current_ext_keyusage = self.cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value + usages = [crypto_utils.cryptography_get_ext_keyusage(usage) for usage in self.extended_key_usage] + expected_ext_keyusage = x509.ExtendedKeyUsage(usages) + if not compare_sets(expected_ext_keyusage, current_ext_keyusage, self.extended_key_usage_strict): + return [eku.value for eku in expected_ext_keyusage], [eku.value for eku in current_ext_keyusage] + + except cryptography.x509.ExtensionNotFound: + # This is only bad if the user specified a non-empty list + if self.extended_key_usage: + return NO_EXTENSION + + def _validate_subject_alt_name(self): + try: + current_san = self.cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + expected_san = [crypto_utils.cryptography_get_name(san) for san in self.subject_alt_name] + if not compare_sets(expected_san, current_san, self.subject_alt_name_strict): + return self.subject_alt_name, current_san + except cryptography.x509.ExtensionNotFound: + # This is only bad if the user specified a non-empty list + if self.subject_alt_name: + return NO_EXTENSION + + def _validate_not_before(self): + return self.cert.not_valid_before + + def _validate_not_after(self): + return self.cert.not_valid_after + + def _validate_valid_at(self): + rt = self.get_relative_time_option(self.valid_at, 'valid_at') + return self.cert.not_valid_before, rt, self.cert.not_valid_after + + def _validate_invalid_at(self): + rt = self.get_relative_time_option(self.valid_at, 'valid_at') + return self.cert.not_valid_before, rt, self.cert.not_valid_after + + def _validate_valid_in(self): + valid_in_date = self.get_relative_time_option(self.valid_in, "valid_in") + return self.cert.not_valid_before, valid_in_date, self.cert.not_valid_after + + +class AssertOnlyCertificate(AssertOnlyCertificateBase): """validate the supplied certificate.""" def __init__(self, module): super(AssertOnlyCertificate, self).__init__(module, 'pyopenssl') - self.signature_algorithms = module.params['signature_algorithms'] - if module.params['subject']: - self.subject = crypto_utils.parse_name_field(module.params['subject']) - else: - self.subject = [] - self.subject_strict = module.params['subject_strict'] - if module.params['issuer']: - self.issuer = crypto_utils.parse_name_field(module.params['issuer']) - else: - self.issuer = [] - self.issuer_strict = module.params['issuer_strict'] - self.has_expired = module.params['has_expired'] - self.version = module.params['version'] - self.keyUsage = module.params['key_usage'] - self.keyUsage_strict = module.params['key_usage_strict'] - self.extendedKeyUsage = module.params['extended_key_usage'] - self.extendedKeyUsage_strict = module.params['extended_key_usage_strict'] - self.subjectAltName = module.params['subject_alt_name'] - self.subjectAltName_strict = module.params['subject_alt_name_strict'] - self.notBefore = module.params['not_before'] - self.notAfter = module.params['not_after'] - self.valid_at = module.params['valid_at'] - self.invalid_at = module.params['invalid_at'] - self.valid_in = module.params['valid_in'] - self.message = [] - self._sanitize_inputs() - - def _sanitize_inputs(self): - """Ensure inputs are properly sanitized before comparison.""" - - for param in ['signature_algorithms', 'keyUsage', 'extendedKeyUsage', - 'subjectAltName', 'subject', 'issuer', 'notBefore', - 'notAfter', 'valid_at', 'invalid_at']: + # Ensure inputs are properly sanitized before comparison. + for param in ['signature_algorithms', 'key_usage', 'extended_key_usage', + 'subject_alt_name', 'subject', 'issuer', 'not_before', + 'not_after', 'valid_at', 'invalid_at']: attr = getattr(self, param) if isinstance(attr, list) and attr: if isinstance(attr[0], str): @@ -1403,213 +1557,134 @@ class AssertOnlyCertificate(Certificate): elif isinstance(attr, str): setattr(self, param, to_bytes(attr)) - def assertonly(self): - - self.cert = crypto_utils.load_certificate(self.path) - - def _validate_signature_algorithms(): - if self.signature_algorithms: - if self.cert.get_signature_algorithm() not in self.signature_algorithms: - self.message.append( - 'Invalid signature algorithm (got %s, expected one of %s)' % (self.cert.get_signature_algorithm(), self.signature_algorithms) - ) - - def _validate_subject(): - if self.subject: - expected_subject = [(OpenSSL._util.lib.OBJ_txt2nid(sub[0]), sub[1]) for sub in self.subject] - cert_subject = self.cert.get_subject().get_components() - current_subject = [(OpenSSL._util.lib.OBJ_txt2nid(sub[0]), sub[1]) for sub in cert_subject] - if (not self.subject_strict and not all(x in current_subject for x in expected_subject)) or \ - (self.subject_strict and not set(expected_subject) == set(current_subject)): - self.message.append( - 'Invalid subject component (got %s, expected all of %s to be present)' % (cert_subject, self.subject) - ) - - def _validate_issuer(): - if self.issuer: - expected_issuer = [(OpenSSL._util.lib.OBJ_txt2nid(iss[0]), iss[1]) for iss in self.issuer] - cert_issuer = self.cert.get_issuer().get_components() - current_issuer = [(OpenSSL._util.lib.OBJ_txt2nid(iss[0]), iss[1]) for iss in cert_issuer] - if (not self.issuer_strict and not all(x in current_issuer for x in expected_issuer)) or \ - (self.issuer_strict and not set(expected_issuer) == set(current_issuer)): - self.message.append( - 'Invalid issuer component (got %s, expected all of %s to be present)' % (cert_issuer, self.issuer) - ) - - def _validate_has_expired(): - # The following 3 lines are the same as the current PyOpenSSL code for cert.has_expired(). - # Older version of PyOpenSSL have a buggy implementation, - # to avoid issues with those we added the code from a more recent release here. - - time_string = to_native(self.cert.get_notAfter()) - not_after = datetime.datetime.strptime(time_string, "%Y%m%d%H%M%SZ") - cert_expired = not_after < datetime.datetime.utcnow() - - if self.has_expired != cert_expired: - self.message.append( - 'Certificate expiration check failed (certificate expiration is %s, expected %s)' % (cert_expired, self.has_expired) - ) - - def _validate_version(): - if self.version: - # Version numbers in certs are off by one: - # v1: 0, v2: 1, v3: 2 ... - if self.version != self.cert.get_version() + 1: - self.message.append( - 'Invalid certificate version number (got %s, expected %s)' % (self.cert.get_version() + 1, self.version) - ) - - def _validate_keyUsage(): - if self.keyUsage: - found = False - for extension_idx in range(0, self.cert.get_extension_count()): - extension = self.cert.get_extension(extension_idx) - if extension.get_short_name() == b'keyUsage': - found = True - keyUsage = [OpenSSL._util.lib.OBJ_txt2nid(keyUsage) for keyUsage in self.keyUsage] - current_ku = [OpenSSL._util.lib.OBJ_txt2nid(usage.strip()) for usage in - to_bytes(extension, errors='surrogate_or_strict').split(b',')] - if (not self.keyUsage_strict and not all(x in current_ku for x in keyUsage)) or \ - (self.keyUsage_strict and not set(keyUsage) == set(current_ku)): - self.message.append( - 'Invalid keyUsage component (got %s, expected all of %s to be present)' % (str(extension).split(', '), self.keyUsage) - ) - if not found: - self.message.append('Found no keyUsage extension') - - def _validate_extendedKeyUsage(): - if self.extendedKeyUsage: - found = False - for extension_idx in range(0, self.cert.get_extension_count()): - extension = self.cert.get_extension(extension_idx) - if extension.get_short_name() == b'extendedKeyUsage': - found = True - extKeyUsage = [OpenSSL._util.lib.OBJ_txt2nid(keyUsage) for keyUsage in self.extendedKeyUsage] - current_xku = [OpenSSL._util.lib.OBJ_txt2nid(usage.strip()) for usage in - to_bytes(extension, errors='surrogate_or_strict').split(b',')] - if (not self.extendedKeyUsage_strict and not all(x in current_xku for x in extKeyUsage)) or \ - (self.extendedKeyUsage_strict and not set(extKeyUsage) == set(current_xku)): - self.message.append( - 'Invalid extendedKeyUsage component (got %s, expected all of %s to be present)' % (str(extension).split(', '), - self.extendedKeyUsage) - ) - if not found: - self.message.append('Found no extendedKeyUsage extension') - - def _validate_subjectAltName(): - if self.subjectAltName: - found = False - for extension_idx in range(0, self.cert.get_extension_count()): - extension = self.cert.get_extension(extension_idx) - if extension.get_short_name() == b'subjectAltName': - found = True - l_altnames = [altname.replace(b'IP Address', b'IP') for altname in - to_bytes(extension, errors='surrogate_or_strict').split(b', ')] - if (not self.subjectAltName_strict and not all(x in l_altnames for x in self.subjectAltName)) or \ - (self.subjectAltName_strict and not set(self.subjectAltName) == set(l_altnames)): - self.message.append( - 'Invalid subjectAltName component (got %s, expected all of %s to be present)' % (l_altnames, self.subjectAltName) - ) - if not found: - self.message.append('Found no subjectAltName extension') - - def _validate_notBefore(): - if self.notBefore: - if self.cert.get_notBefore() != self.notBefore: - self.message.append( - 'Invalid notBefore component (got %s, expected %s to be present)' % (self.cert.get_notBefore(), self.notBefore) - ) - - def _validate_notAfter(): - if self.notAfter: - if self.cert.get_notAfter() != self.notAfter: - self.message.append( - 'Invalid notAfter component (got %s, expected %s to be present)' % (self.cert.get_notAfter(), self.notAfter) - ) - - def _validate_valid_at(): - if self.valid_at: - if not (self.cert.get_notBefore() <= self.valid_at <= self.cert.get_notAfter()): - self.message.append( - 'Certificate is not valid for the specified date (%s) - notBefore: %s - notAfter: %s' % (self.valid_at, - self.cert.get_notBefore(), - self.cert.get_notAfter()) - ) - - def _validate_invalid_at(): - if self.invalid_at: - if not (self.invalid_at <= self.cert.get_notBefore() or self.invalid_at >= self.cert.get_notAfter()): - self.message.append( - 'Certificate is not invalid for the specified date (%s) - notBefore: %s - notAfter: %s' % (self.invalid_at, - self.cert.get_notBefore(), - self.cert.get_notAfter()) - ) - - def _validate_valid_in(): - if self.valid_in: - if not self.valid_in.startswith("+") and not self.valid_in.startswith("-"): - try: - int(self.valid_in) - except ValueError: - raise CertificateError( - 'The supplied value for "valid_in" (%s) is not an integer or a valid timespec' % self.valid_in) - self.valid_in = "+" + self.valid_in + "s" - valid_in_asn1 = self.get_relative_time_option(self.valid_in, "valid_in") - valid_in_date = to_bytes(valid_in_asn1, errors='surrogate_or_strict') - if not (self.cert.get_notBefore() <= valid_in_date <= self.cert.get_notAfter()): - self.message.append( - 'Certificate is not valid in %s from now (that would be %s) - notBefore: %s - notAfter: %s' - % (self.valid_in, valid_in_date, - self.cert.get_notBefore(), - self.cert.get_notAfter())) - - for validation in ['signature_algorithms', 'subject', 'issuer', - 'has_expired', 'version', 'keyUsage', - 'extendedKeyUsage', 'subjectAltName', - 'notBefore', 'notAfter', 'valid_at', - 'invalid_at', 'valid_in']: - f_name = locals()['_validate_%s' % validation] - f_name() - - def generate(self, module): - """Don't generate anything - assertonly""" - - self.assertonly() - + def _validate_privatekey(self): + ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_2_METHOD) + ctx.use_privatekey(self.privatekey) + ctx.use_certificate(self.cert) try: - if self.privatekey_path and \ - not super(AssertOnlyCertificate, self).check(module, perms_required=False): - self.message.append( - 'Certificate %s and private key %s do not match' % (self.path, self.privatekey_path) - ) - except CertificateError as e: - self.message.append( - 'Error while reading private key %s: %s' % (self.privatekey_path, str(e)) - ) + ctx.check_privatekey() + return True + except OpenSSL.SSL.Error: + return False - if len(self.message): - module.fail_json(msg=' | '.join(self.message)) + def _validate_csr_signature(self): + try: + self.csr.verify(self.cert.get_pubkey()) + except OpenSSL.crypto.Error: + return False - def check(self, module, perms_required=True): - """Ensure the resource is in its desired state.""" + def _validate_csr_subject(self): + if self.csr.get_subject() != self.cert.get_subject(): + return False - parent_check = super(AssertOnlyCertificate, self).check(module, perms_required) - self.assertonly() - assertonly_check = not len(self.message) - self.message = [] + def _validate_csr_extensions(self): + csr_extensions = self.csr.get_extensions() + cert_extension_count = self.cert.get_extension_count() + if len(csr_extensions) != cert_extension_count: + return False + for extension_number in range(0, cert_extension_count): + cert_extension = self.cert.get_extension(extension_number) + csr_extension = filter(lambda extension: extension.get_short_name() == cert_extension.get_short_name(), csr_extensions) + if cert_extension.get_data() != list(csr_extension)[0].get_data(): + return False + return True - return parent_check and assertonly_check + def _validate_signature_algorithms(self): + if self.cert.get_signature_algorithm() not in self.signature_algorithms: + return self.cert.get_signature_algorithm() - def dump(self, check_mode=False): + def _validate_subject(self): + expected_subject = [(OpenSSL._util.lib.OBJ_txt2nid(sub[0]), sub[1]) for sub in self.subject] + cert_subject = self.cert.get_subject().get_components() + current_subject = [(OpenSSL._util.lib.OBJ_txt2nid(sub[0]), sub[1]) for sub in cert_subject] + if not compare_sets(expected_subject, current_subject, self.subject_strict): + return expected_subject, current_subject - result = { - 'changed': self.changed, - 'filename': self.path, - 'privatekey': self.privatekey_path, - 'csr': self.csr_path, - } + def _validate_issuer(self): + expected_issuer = [(OpenSSL._util.lib.OBJ_txt2nid(iss[0]), iss[1]) for iss in self.issuer] + cert_issuer = self.cert.get_issuer().get_components() + current_issuer = [(OpenSSL._util.lib.OBJ_txt2nid(iss[0]), iss[1]) for iss in cert_issuer] + if not compare_sets(expected_issuer, current_issuer, self.issuer_strict): + return self.issuer, cert_issuer - return result + def _validate_has_expired(self): + # The following 3 lines are the same as the current PyOpenSSL code for cert.has_expired(). + # Older version of PyOpenSSL have a buggy implementation, + # to avoid issues with those we added the code from a more recent release here. + + time_string = to_native(self.cert.get_notAfter()) + not_after = datetime.datetime.strptime(time_string, "%Y%m%d%H%M%SZ") + cert_expired = not_after < datetime.datetime.utcnow() + return cert_expired + + def _validate_version(self): + # Version numbers in certs are off by one: + # v1: 0, v2: 1, v3: 2 ... + return self.cert.get_version() + 1 + + def _validate_key_usage(self): + found = False + for extension_idx in range(0, self.cert.get_extension_count()): + extension = self.cert.get_extension(extension_idx) + if extension.get_short_name() == b'keyUsage': + found = True + key_usage = [OpenSSL._util.lib.OBJ_txt2nid(key_usage) for key_usage in self.key_usage] + current_ku = [OpenSSL._util.lib.OBJ_txt2nid(usage.strip()) for usage in + to_bytes(extension, errors='surrogate_or_strict').split(b',')] + if not compare_sets(key_usage, current_ku, self.key_usage_strict): + return self.key_usage, str(extension).split(', ') + if not found: + # This is only bad if the user specified a non-empty list + if self.key_usage: + return NO_EXTENSION + + def _validate_extended_key_usage(self): + found = False + for extension_idx in range(0, self.cert.get_extension_count()): + extension = self.cert.get_extension(extension_idx) + if extension.get_short_name() == b'extendedKeyUsage': + found = True + extKeyUsage = [OpenSSL._util.lib.OBJ_txt2nid(keyUsage) for keyUsage in self.extended_key_usage] + current_xku = [OpenSSL._util.lib.OBJ_txt2nid(usage.strip()) for usage in + to_bytes(extension, errors='surrogate_or_strict').split(b',')] + if not compare_sets(extKeyUsage, current_xku, self.extended_key_usage_strict): + return self.extended_key_usage, str(extension).split(', ') + if not found: + # This is only bad if the user specified a non-empty list + if self.extended_key_usage: + return NO_EXTENSION + + def _validate_subject_alt_name(self): + found = False + for extension_idx in range(0, self.cert.get_extension_count()): + extension = self.cert.get_extension(extension_idx) + if extension.get_short_name() == b'subjectAltName': + found = True + l_altnames = [altname.replace(b'IP Address', b'IP') for altname in + to_bytes(extension, errors='surrogate_or_strict').split(b', ')] + if not compare_sets(self.subject_alt_name, l_altnames, self.subject_alt_name_strict): + return self.subject_alt_name, l_altnames + if not found: + # This is only bad if the user specified a non-empty list + if self.subject_alt_name: + return NO_EXTENSION + + def _validate_not_before(self): + return self.cert.get_notBefore() + + def _validate_not_after(self): + return self.cert.get_notAfter() + + def _validate_valid_at(self): + return self.cert.get_notBefore(), self.valid_at, self.cert.get_notAfter() + + def _validate_invalid_at(self): + return self.cert.get_notBefore(), self.valid_at, self.cert.get_notAfter() + + def _validate_valid_in(self): + valid_in_asn1 = self.get_relative_time_option(self.valid_in, "valid_in") + valid_in_date = to_bytes(valid_in_asn1, errors='surrogate_or_strict') + return self.cert.get_notBefore(), valid_in_date, self.cert.get_notAfter() class AcmeCertificate(Certificate):