[stable-2.7] Catch sshpass authentication errors and don't retry multiple times to prevent account lockout (#50776)

* Catch SSH authentication errors and don't retry multiple times to prevent account lock out

Signed-off-by: Sam Doran <sdoran@redhat.com>

* Subclass AnsibleAuthenticationFailure from AnsibleConnectionFailure

Use comparison rather than range() because it's much more efficient.

Signed-off-by: Sam Doran <sdoran@redhat.com>

* Add tests

Signed-off-by: Sam Doran <sdoran@redhat.com>

* Make paramiko_ssh connection plugin behave the same way

Signed-off-by: Sam Doran <sdoran@redhat.com>

* Add changelog

Signed-off-by: Sam Doran <sdoran@redhat.com>.
(cherry picked from commit 9d4c0dc111)

Co-authored-by: Sam Doran <sdoran@redhat.com>
Signed-off-by: Sam Doran <sdoran@redhat.com>
This commit is contained in:
Sam Doran 2019-01-23 11:32:25 -05:00 committed by Toshio Kuratomi
parent f67081e97b
commit 44d7c1e23e
5 changed files with 115 additions and 23 deletions

View file

@ -0,0 +1,2 @@
bugfixes:
- ssh connection - do not retry with invalid credentials to prevent account lockout (https://github.com/ansible/ansible/issues/48422)

View file

@ -209,6 +209,11 @@ class AnsibleConnectionFailure(AnsibleRuntimeError):
pass pass
class AnsibleAuthenticationFailure(AnsibleConnectionFailure):
'''invalid username/password/key'''
pass
class AnsibleFilterError(AnsibleRuntimeError): class AnsibleFilterError(AnsibleRuntimeError):
''' a templating failure ''' ''' a templating failure '''
pass pass

View file

@ -141,12 +141,17 @@ from distutils.version import LooseVersion
from binascii import hexlify from binascii import hexlify
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound from ansible.errors import (
AnsibleAuthenticationFailure,
AnsibleConnectionFailure,
AnsibleError,
AnsibleFileNotFound,
)
from ansible.module_utils.six import iteritems from ansible.module_utils.six import iteritems
from ansible.module_utils.six.moves import input from ansible.module_utils.six.moves import input
from ansible.plugins.connection import ConnectionBase from ansible.plugins.connection import ConnectionBase
from ansible.utils.path import makedirs_safe from ansible.utils.path import makedirs_safe
from ansible.module_utils._text import to_bytes, to_native from ansible.module_utils._text import to_bytes, to_native, to_text
try: try:
from __main__ import display from __main__ import display
@ -358,6 +363,9 @@ class Connection(ConnectionBase):
) )
except paramiko.ssh_exception.BadHostKeyException as e: except paramiko.ssh_exception.BadHostKeyException as e:
raise AnsibleConnectionFailure('host key mismatch for %s' % e.hostname) raise AnsibleConnectionFailure('host key mismatch for %s' % e.hostname)
except paramiko.ssh_exception.AuthenticationException as e:
msg = 'Invalid/incorrect username/password. {0}'.format(to_text(e))
raise AnsibleAuthenticationFailure(msg)
except Exception as e: except Exception as e:
msg = str(e) msg = str(e)
if "PID check failed" in msg: if "PID check failed" in msg:

View file

@ -151,8 +151,8 @@ DOCUMENTATION = '''
- section: ssh_connection - section: ssh_connection
key: retries key: retries
vars: vars:
- name: ansible_ssh_retries - name: ansible_ssh_retries
version_added: '2.7' version_added: '2.7'
port: port:
description: Remote port to connect to. description: Remote port to connect to.
type: int type: int
@ -280,7 +280,12 @@ import time
from functools import wraps from functools import wraps
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleFileNotFound from ansible.errors import (
AnsibleAuthenticationFailure,
AnsibleConnectionFailure,
AnsibleError,
AnsibleFileNotFound,
)
from ansible.errors import AnsibleOptionsError from ansible.errors import AnsibleOptionsError
from ansible.compat import selectors from ansible.compat import selectors
from ansible.module_utils.six import PY3, text_type, binary_type from ansible.module_utils.six import PY3, text_type, binary_type
@ -310,6 +315,55 @@ class AnsibleControlPersistBrokenPipeError(AnsibleError):
pass pass
def _handle_error(remaining_retries, command, return_tuple, no_log, host, display=display):
# sshpass errors
if command == b'sshpass':
# Error 5 is invalid/incorrect password. Raise an exception to prevent retries from locking the account.
if return_tuple[0] == 5:
msg = 'Invalid/incorrect username/password. Skipping remaining {0} retries to prevent account lockout:'.format(remaining_retries)
if remaining_retries <= 0:
msg = 'Invalid/incorrect password:'
if no_log:
msg = '{0} <error censored due to no log>'.format(msg)
else:
msg = '{0} {1}'.format(msg, to_native(return_tuple[2].rstrip()))
raise AnsibleAuthenticationFailure(msg)
# sshpass returns codes are 1-6. We handle 5 previously, so this catches other scenarios.
# No exception is raised, so the connection is retried.
elif return_tuple[0] in [1, 2, 3, 4, 6]:
msg = 'sshpass error:'
if no_log:
msg = '{0} <error censored due to no log>'.format(msg)
else:
msg = '{0} {1}'.format(msg, to_native(return_tuple[2].rstrip()))
if return_tuple[0] == 255:
SSH_ERROR = True
for signature in b_NOT_SSH_ERRORS:
if signature in return_tuple[1]:
SSH_ERROR = False
break
if SSH_ERROR:
msg = "Failed to connect to the host via ssh:"
if no_log:
msg = '{0} <error censored due to no log>'.format(msg)
else:
msg = '{0} {1}'.format(msg, to_native(return_tuple[2]).rstrip())
raise AnsibleConnectionFailure(msg)
# For other errors, no execption is raised so the connection is retried and we only log the messages
if 1 <= return_tuple[0] <= 254:
msg = "Failed to connect to the host via ssh:"
if no_log:
msg = '{0} <error censored due to no log>'.format(msg)
else:
msg = '{0} {1}'.format(msg, to_native(return_tuple[2]).rstrip())
display.vvv(msg, host=host)
def _ssh_retry(func): def _ssh_retry(func):
""" """
Decorator to retry ssh/scp/sftp in the case of a connection failure Decorator to retry ssh/scp/sftp in the case of a connection failure
@ -318,7 +372,8 @@ def _ssh_retry(func):
* an exception is caught * an exception is caught
* ssh returns 255 * ssh returns 255
Will not retry if Will not retry if
* remaining_tries is <2 * sshpass returns 5 (invalid password, to prevent account lockouts)
* remaining_tries is < 2
* retries limit reached * retries limit reached
""" """
@wraps(func) @wraps(func)
@ -336,7 +391,7 @@ def _ssh_retry(func):
try: try:
return_tuple = func(self, *args, **kwargs) return_tuple = func(self, *args, **kwargs)
if self._play_context.no_log: if self._play_context.no_log:
display.vvv('rc=%s, stdout & stderr censored due to no log' % return_tuple[0], host=self.host) display.vvv('rc=%s, stdout and stderr censored due to no log' % return_tuple[0], host=self.host)
else: else:
display.vvv(return_tuple, host=self.host) display.vvv(return_tuple, host=self.host)
# 0 = success # 0 = success
@ -352,24 +407,18 @@ def _ssh_retry(func):
display.vvv(u"RETRYING BECAUSE OF CONTROLPERSIST BROKEN PIPE") display.vvv(u"RETRYING BECAUSE OF CONTROLPERSIST BROKEN PIPE")
return_tuple = func(self, *args, **kwargs) return_tuple = func(self, *args, **kwargs)
if return_tuple[0] == 255: remaining_retries = remaining_tries - attempt - 1
SSH_ERROR = True _handle_error(remaining_retries, cmd[0], return_tuple, self._play_context.no_log, self.host)
for signature in b_NOT_SSH_ERRORS:
if signature in return_tuple[1]:
SSH_ERROR = False
break
if SSH_ERROR:
msg = "Failed to connect to the host via ssh: "
if self._play_context.no_log:
msg += '<error censored due to no log>'
else:
msg += to_native(return_tuple[2])
raise AnsibleConnectionFailure(msg)
break break
# 5 = Invalid/incorrect password from sshpass
except AnsibleAuthenticationFailure as e:
# Raising this exception, which is subclassed from AnsibleConnectionFailure, prevents further retries
raise
except (AnsibleConnectionFailure, Exception) as e: except (AnsibleConnectionFailure, Exception) as e:
if attempt == remaining_tries - 1: if attempt == remaining_tries - 1:
raise raise
else: else:
@ -378,9 +427,9 @@ def _ssh_retry(func):
pause = 30 pause = 30
if isinstance(e, AnsibleConnectionFailure): if isinstance(e, AnsibleConnectionFailure):
msg = "ssh_retry: attempt: %d, ssh return code is 255. cmd (%s), pausing for %d seconds" % (attempt, cmd_summary, pause) msg = "ssh_retry: attempt: %d, ssh return code is 255. cmd (%s), pausing for %d seconds" % (attempt + 1, cmd_summary, pause)
else: else:
msg = "ssh_retry: attempt: %d, caught exception(%s) from cmd (%s), pausing for %d seconds" % (attempt, e, cmd_summary, pause) msg = "ssh_retry: attempt: %d, caught exception(%s) from cmd (%s), pausing for %d seconds" % (attempt + 1, e, cmd_summary, pause)
display.vv(msg, host=self.host) display.vv(msg, host=self.host)

View file

@ -25,6 +25,7 @@ import pytest
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleAuthenticationFailure
from ansible.compat.selectors import SelectorKey, EVENT_READ from ansible.compat.selectors import SelectorKey, EVENT_READ
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch, MagicMock, PropertyMock from ansible.compat.tests.mock import patch, MagicMock, PropertyMock
@ -501,6 +502,33 @@ class TestSSHConnectionRun(object):
@pytest.mark.usefixtures('mock_run_env') @pytest.mark.usefixtures('mock_run_env')
class TestSSHConnectionRetries(object): class TestSSHConnectionRetries(object):
def test_incorrect_password(self, monkeypatch):
monkeypatch.setattr(C, 'HOST_KEY_CHECKING', False)
monkeypatch.setattr(C, 'ANSIBLE_SSH_RETRIES', 5)
monkeypatch.setattr('time.sleep', lambda x: None)
self.mock_popen_res.stdout.read.side_effect = [b'']
self.mock_popen_res.stderr.read.side_effect = [b'Permission denied, please try again.\r\n']
type(self.mock_popen_res).returncode = PropertyMock(side_effect=[5] * 4)
self.mock_selector.select.side_effect = [
[(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
[(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
[],
]
self.mock_selector.get_map.side_effect = lambda: True
self.conn._build_command = MagicMock()
self.conn._build_command.return_value = [b'sshpass', b'-d41', b'ssh', b'-C']
self.conn.get_option = MagicMock()
self.conn.get_option.return_value = True
exception_info = pytest.raises(AnsibleAuthenticationFailure, self.conn.exec_command, 'sshpass', 'some data')
assert exception_info.value.message == ('Invalid/incorrect username/password. Skipping remaining 5 retries to prevent account lockout: '
'Permission denied, please try again.')
assert self.mock_popen.call_count == 1
def test_retry_then_success(self, monkeypatch): def test_retry_then_success(self, monkeypatch):
monkeypatch.setattr(C, 'HOST_KEY_CHECKING', False) monkeypatch.setattr(C, 'HOST_KEY_CHECKING', False)
monkeypatch.setattr(C, 'ANSIBLE_SSH_RETRIES', 3) monkeypatch.setattr(C, 'ANSIBLE_SSH_RETRIES', 3)