From f2211058826944a4e8e9e9b678d26615a29977b3 Mon Sep 17 00:00:00 2001 From: Nathaniel Case Date: Fri, 10 Aug 2018 09:26:58 -0400 Subject: [PATCH] Prevent data being truncated over persistent connection socket (#43885) * Change how data is sent to the persistent connection socket. We can't rely on readline(), so send the size of the data first. We can then read that many bytes from the stream on the recieving end. * Set pty to noncanonical mode before sending * Now that we send data length, we don't need a sentinel anymore * Copy socket changes to persistent, too * Use os.write instead of fdopen()ing and using that. * Follow pickle with sha1sum of pickle * Swap order of vars and init being passed to ansible-connection --- bin/ansible-connection | 39 ++++++++++---------- lib/ansible/executor/task_executor.py | 35 ++++++++---------- lib/ansible/module_utils/connection.py | 25 +++++++++++++ lib/ansible/plugins/connection/persistent.py | 34 ++++++++--------- 4 files changed, 77 insertions(+), 56 deletions(-) diff --git a/bin/ansible-connection b/bin/ansible-connection index a9b381c63a0..43b3675a217 100755 --- a/bin/ansible-connection +++ b/bin/ansible-connection @@ -12,6 +12,7 @@ except Exception: pass import fcntl +import hashlib import os import signal import socket @@ -36,6 +37,23 @@ from ansible.utils.display import Display from ansible.utils.jsonrpc import JsonRpcServer +def read_stream(byte_stream): + size = int(byte_stream.readline().strip()) + + data = byte_stream.read(size) + if len(data) < size: + raise Exception("EOF found before data was complete") + + data_hash = to_text(byte_stream.readline().strip()) + if data_hash != hashlib.sha1(data).hexdigest(): + raise Exception("Read {0} bytes, but data did not match checksum".format(size)) + + # restore escaped loose \r characters + data = data.replace(br'\r', b'\r') + + return data + + @contextmanager def file_lock(lock_path): """ @@ -204,25 +222,8 @@ def main(): try: # read the play context data via stdin, which means depickling it - cur_line = stdin.readline() - init_data = b'' - - while cur_line.strip() != b'#END_INIT#': - if cur_line == b'': - raise Exception("EOF found before init data was complete") - init_data += cur_line - cur_line = stdin.readline() - - cur_line = stdin.readline() - vars_data = b'' - - while cur_line.strip() != b'#END_VARS#': - if cur_line == b'': - raise Exception("EOF found before vars data was complete") - vars_data += cur_line - cur_line = stdin.readline() - # restore escaped loose \r characters - vars_data = vars_data.replace(br'\r', b'\r') + vars_data = read_stream(stdin) + init_data = read_stream(stdin) if PY3: pc_data = cPickle.loads(init_data, encoding='bytes') diff --git a/lib/ansible/executor/task_executor.py b/lib/ansible/executor/task_executor.py index 0b2c5b4a7f3..a60d2082217 100644 --- a/lib/ansible/executor/task_executor.py +++ b/lib/ansible/executor/task_executor.py @@ -10,14 +10,15 @@ import time import json import subprocess import sys +import termios import traceback from ansible import constants as C from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleConnectionFailure, AnsibleActionFail, AnsibleActionSkip from ansible.executor.task_result import TaskResult from ansible.module_utils.six import iteritems, string_types, binary_type -from ansible.module_utils.six.moves import cPickle from ansible.module_utils._text import to_text, to_native +from ansible.module_utils.connection import write_to_file_descriptor from ansible.playbook.conditional import Conditional from ansible.playbook.task import Task from ansible.template import Templar @@ -920,28 +921,24 @@ class TaskExecutor: [python, find_file_in_path('ansible-connection'), to_text(os.getppid())], stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - stdin = os.fdopen(master, 'wb', 0) os.close(slave) - # Need to force a protocol that is compatible with both py2 and py3. - # That would be protocol=2 or less. - # Also need to force a protocol that excludes certain control chars as - # stdin in this case is a pty and control chars will cause problems. - # that means only protocol=0 will work. - src = cPickle.dumps(self._play_context.serialize(), protocol=0) - stdin.write(src) - stdin.write(b'\n#END_INIT#\n') + # We need to set the pty into noncanonical mode. This ensures that we + # can receive lines longer than 4095 characters (plus newline) without + # truncating. + old = termios.tcgetattr(master) + new = termios.tcgetattr(master) + new[3] = new[3] & ~termios.ICANON - src = cPickle.dumps(variables, protocol=0) - # remaining \r fail to round-trip the socket - src = src.replace(b'\r', br'\r') - stdin.write(src) - stdin.write(b'\n#END_VARS#\n') + try: + termios.tcsetattr(master, termios.TCSANOW, new) + write_to_file_descriptor(master, variables) + write_to_file_descriptor(master, self._play_context.serialize()) - stdin.flush() - - (stdout, stderr) = p.communicate() - stdin.close() + (stdout, stderr) = p.communicate() + finally: + termios.tcsetattr(master, termios.TCSANOW, old) + os.close(master) if p.returncode == 0: result = json.loads(to_text(stdout, errors='surrogate_then_replace')) diff --git a/lib/ansible/module_utils/connection.py b/lib/ansible/module_utils/connection.py index 57d019fe7bf..08017a52a1c 100644 --- a/lib/ansible/module_utils/connection.py +++ b/lib/ansible/module_utils/connection.py @@ -27,6 +27,7 @@ # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import os +import hashlib import json import socket import struct @@ -36,6 +37,30 @@ import uuid from functools import partial from ansible.module_utils._text import to_bytes, to_text from ansible.module_utils.six import iteritems +from ansible.module_utils.six.moves import cPickle + + +def write_to_file_descriptor(fd, obj): + """Handles making sure all data is properly written to file descriptor fd. + + In particular, that data is encoded in a character stream-friendly way and + that all data gets written before returning. + """ + # Need to force a protocol that is compatible with both py2 and py3. + # That would be protocol=2 or less. + # Also need to force a protocol that excludes certain control chars as + # stdin in this case is a pty and control chars will cause problems. + # that means only protocol=0 will work. + src = cPickle.dumps(obj, protocol=0) + + # raw \r characters will not survive pty round-trip + # They should be rehydrated on the receiving end + src = src.replace(b'\r', br'\r') + data_hash = to_bytes(hashlib.sha1(src).hexdigest()) + + os.write(fd, b'%d\n' % len(src)) + os.write(fd, src) + os.write(fd, b'%s\n' % data_hash) def send_data(s, data): diff --git a/lib/ansible/plugins/connection/persistent.py b/lib/ansible/plugins/connection/persistent.py index 96cdb988af6..800a42843c5 100644 --- a/lib/ansible/plugins/connection/persistent.py +++ b/lib/ansible/plugins/connection/persistent.py @@ -34,12 +34,12 @@ import pty import json import subprocess import sys +import termios from ansible import constants as C from ansible.plugins.connection import ConnectionBase from ansible.module_utils._text import to_text -from ansible.module_utils.six.moves import cPickle -from ansible.module_utils.connection import Connection as SocketConnection +from ansible.module_utils.connection import Connection as SocketConnection, write_to_file_descriptor from ansible.errors import AnsibleError try: @@ -109,26 +109,24 @@ class Connection(ConnectionBase): [python, find_file_in_path('ansible-connection'), to_text(os.getppid())], stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - stdin = os.fdopen(master, 'wb', 0) os.close(slave) - # Need to force a protocol that is compatible with both py2 and py3. - # That would be protocol=2 or less. - # Also need to force a protocol that excludes certain control chars as - # stdin in this case is a pty and control chars will cause problems. - # that means only protocol=0 will work. - src = cPickle.dumps(self._play_context.serialize(), protocol=0) - stdin.write(src) - stdin.write(b'\n#END_INIT#\n') + # We need to set the pty into noncanonical mode. This ensures that we + # can receive lines longer than 4095 characters (plus newline) without + # truncating. + old = termios.tcgetattr(master) + new = termios.tcgetattr(master) + new[3] = new[3] & ~termios.ICANON - src = cPickle.dumps({'ansible_command_timeout': self.get_option('persistent_command_timeout')}, protocol=0) - stdin.write(src) - stdin.write(b'\n#END_VARS#\n') + try: + termios.tcsetattr(master, termios.TCSANOW, new) + write_to_file_descriptor(master, {'ansible_command_timeout': self.get_option('persistent_command_timeout')}) + write_to_file_descriptor(master, self._play_context.serialize()) - stdin.flush() - - (stdout, stderr) = p.communicate() - stdin.close() + (stdout, stderr) = p.communicate() + finally: + termios.tcsetattr(master, termios.TCSANOW, old) + os.close(master) if p.returncode == 0: result = json.loads(to_text(stdout, errors='surrogate_then_replace'))