ansible/utilities/logic/wait_for.py

489 lines
18 KiB
Python
Raw Normal View History

#!/usr/bin/python
# -*- coding: utf-8 -*-
# (c) 2012, Jeroen Hoekx <jeroen@hoekx.be>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
import socket
import datetime
import time
import sys
import re
import binascii
HAS_PSUTIL = False
try:
import psutil
HAS_PSUTIL = True
# just because we can import it on Linux doesn't mean we will use it
except ImportError:
pass
DOCUMENTATION = '''
---
module: wait_for
short_description: Waits for a condition before continuing.
description:
- You can wait for a set amount of time C(timeout), this is the default if nothing is specified.
- Waiting for a port to become available is useful for when services
are not immediately available after their init scripts return
which is true of certain Java application servers. It is also
useful when starting guests with the M(virt) module and
needing to pause until they are ready.
- This module can also be used to wait for a regex to match a string in a file.
- In 1.6 and later, this module can also be used to wait for a file to be available or
absent on the filesystem.
- In 1.8 and later, this module can also be used to wait for active
connections to be closed before continuing, useful if a node
is being rotated out of a load balancer pool.
version_added: "0.7"
options:
host:
description:
- A resolvable hostname or IP address to wait for
required: false
default: "127.0.0.1"
timeout:
description:
- maximum number of seconds to wait for
required: false
default: 300
connect_timeout:
description:
- maximum number of seconds to wait for a connection to happen before closing and retrying
required: false
default: 5
delay:
description:
- number of seconds to wait before starting to poll
required: false
default: 0
port:
description:
- port number to poll
required: false
state:
description:
- either C(present), C(started), or C(stopped), C(absent), or C(drained)
- When checking a port C(started) will ensure the port is open, C(stopped) will check that it is closed, C(drained) will check for active connections
- When checking for a file or a search string C(present) or C(started) will ensure that the file or string is present before continuing, C(absent) will check that file is absent or removed
choices: [ "present", "started", "stopped", "absent", "drained" ]
default: "started"
path:
version_added: "1.4"
required: false
description:
- path to a file on the filesystem that must exist before continuing
search_regex:
version_added: "1.4"
required: false
description:
- Can be used to match a string in either a file or a socket connection. Defaults to a multiline regex.
exclude_hosts:
version_added: "1.8"
required: false
description:
- list of hosts or IPs to ignore when looking for active TCP connections for C(drained) state
notes:
- The ability to use search_regex with a port connection was added in 1.7.
requirements: []
author:
- "Jeroen Hoekx (@jhoekx)"
- "John Jarvis (@jarv)"
- "Andrii Radyk (@AnderEnder)"
'''
2013-06-01 18:38:16 -04:00
EXAMPLES = '''
2013-06-01 18:38:16 -04:00
# wait 300 seconds for port 8000 to become open on the host, don't start checking for 10 seconds
- wait_for: port=8000 delay=10
# wait 300 seconds for port 8000 of any IP to close active connections, don't start checking for 10 seconds
- wait_for: host=0.0.0.0 port=8000 delay=10 state=drained
# wait 300 seconds for port 8000 of any IP to close active connections, ignoring connections for specified hosts
- wait_for: host=0.0.0.0 port=8000 state=drained exclude_hosts=10.2.1.2,10.2.1.3
# wait until the file /tmp/foo is present before continuing
- wait_for: path=/tmp/foo
# wait until the string "completed" is in the file /tmp/foo before continuing
- wait_for: path=/tmp/foo search_regex=completed
# wait until the lock file is removed
- wait_for: path=/var/lock/file.lock state=absent
# wait until the process is finished and pid was destroyed
- wait_for: path=/proc/3466/status state=absent
# wait 300 seconds for port 22 to become open and contain "OpenSSH", don't assume the inventory_hostname is resolvable
# and don't start checking for 10 seconds
- local_action: wait_for port=22 host="{{ ansible_ssh_host | default(inventory_hostname) }}" search_regex=OpenSSH delay=10
2013-06-01 18:38:16 -04:00
'''
class TCPConnectionInfo(object):
    """
    Generic strategy class for counting active TCP connections on a port.

    It relies on the psutil module, which is not ideal for targets, but is
    necessary for cross platform support.  Instantiation is routed through
    load_platform_subclass() (defined outside this file; presumably it picks
    the subclass whose ``platform``/``distribution`` match the target).

    A subclass may wish to override some or all of these methods.
      - _get_exclude_ips()
      - get_active_connections_count()

    All subclasses MUST define platform and distribution (which may be None).
    """
    platform = 'Generic'
    distribution = None

    # "Any address" wildcard per address family, in the textual form psutil uses.
    match_all_ips = {
        socket.AF_INET: '0.0.0.0',
        socket.AF_INET6: '::',
    }
    # TCP states that count as an active connection.  Keys are two-digit hex
    # state codes (used by the procfs based subclass); values are the state
    # names compared against psutil's ``conn.status`` below.
    connection_states = {
        '01': 'ESTABLISHED',
        '02': 'SYN_SENT',
        '03': 'SYN_RECV',
        '04': 'FIN_WAIT1',
        '05': 'FIN_WAIT2',
        '06': 'TIME_WAIT',
    }

    def __new__(cls, *args, **kwargs):
        return load_platform_subclass(TCPConnectionInfo, args, kwargs)

    def __init__(self, module):
        """
        Args:
            module: the AnsibleModule instance; ``host``, ``port`` and
                ``exclude_hosts`` are read from ``module.params``.
        """
        self.module = module
        # Report the missing dependency before doing any name resolution.
        if not HAS_PSUTIL:
            module.fail_json(msg="psutil module required for wait_for")
        (self.family, self.ip) = _convert_host_to_ip(module.params['host'])
        self.port = int(module.params['port'])
        self.exclude_ips = self._get_exclude_ips()

    def _get_exclude_ips(self):
        """
        Returns:
            List of textual IP addresses whose connections must be ignored
            (empty when ``exclude_hosts`` was not given).
        """
        exclude_hosts = self.module.params['exclude_hosts']
        if exclude_hosts is None:
            return []
        # psutil reports plain textual IPs, so the exclusions must be textual
        # IPs as well; the procfs hex form would never match anything here.
        return [_convert_host_to_ip(h)[1] for h in exclude_hosts]

    def get_active_connections_count(self):
        """
        Count connections in one of ``connection_states`` whose local end is
        the watched ip/port and whose remote end is not excluded.

        Returns:
            Integer number of matching connections over all processes.
        """
        active_states = set(self.connection_states.values())
        wildcard_ip = self.match_all_ips[self.family]
        active_connections = 0
        for p in psutil.process_iter():
            try:
                # psutil 2.0 renamed get_connections() to connections().
                if hasattr(p, 'connections'):
                    connections = p.connections(kind='inet')
                else:
                    connections = p.get_connections(kind='inet')
            except psutil.Error:
                # The process vanished or may not be inspected; skip it
                # instead of aborting the whole count.
                continue
            for conn in connections:
                if conn.status not in active_states:
                    continue
                # psutil 1.0 renamed local_address/remote_address to laddr/raddr.
                if hasattr(conn, 'laddr'):
                    local = conn.laddr
                    remote = conn.raddr
                else:
                    local = conn.local_address
                    remote = conn.remote_address
                if not local or local[1] != self.port:
                    continue
                if self.ip not in (wildcard_ip, local[0]):
                    continue
                # The remote address can be empty for unconnected sockets.
                if remote and remote[0] in self.exclude_ips:
                    continue
                active_connections += 1
        return active_connections
# ===========================================
# Subclass: Linux
class LinuxTCPConnectionInfo(TCPConnectionInfo):
    """
    This is a TCP Connection Info evaluation strategy class
    that utilizes information from Linux's procfs. While less universal,
    does allow Linux targets to not require an additional library.
    """
    platform = 'Linux'
    distribution = None

    # procfs table to read for each address family.
    source_file = {
        socket.AF_INET: '/proc/net/tcp',
        socket.AF_INET6: '/proc/net/tcp6'
    }
    # "Any address" wildcard per family, in the hex form used by procfs.
    match_all_ips = {
        socket.AF_INET: '00000000',
        socket.AF_INET6: '00000000000000000000000000000000',
    }
    # Whitespace separated column indexes within a line of the procfs table.
    local_address_field = 1
    remote_address_field = 2
    connection_state_field = 3

    def __init__(self, module):
        """
        Args:
            module: the AnsibleModule instance; ``host``, ``port`` and
                ``exclude_hosts`` are read from ``module.params``.
        """
        self.module = module
        (self.family, self.ip) = _convert_host_to_hex(module.params['host'])
        # Ports are compared as four upper-case hex digits, e.g. 8000 -> "1F40".
        self.port = "%0.4X" % int(module.params['port'])
        self.exclude_ips = self._get_exclude_ips()

    def _get_exclude_ips(self):
        """
        Returns:
            List of hex encoded addresses whose connections must be ignored
            (empty when ``exclude_hosts`` was not given).
        """
        exclude_hosts = self.module.params['exclude_hosts']
        if exclude_hosts is None:
            return []
        # _convert_host_to_hex() returns (family, hex); only the hex string can
        # ever equal the remote address column, so keep just that element.
        return [_convert_host_to_hex(h)[1] for h in exclude_hosts]

    def get_active_connections_count(self):
        """
        Count rows of the procfs TCP table that are in one of
        ``connection_states``, whose local end is the watched ip/port and
        whose remote end is not excluded.

        Returns:
            Integer number of matching connections.

        Raises:
            IOError: if the procfs table cannot be opened or read.
        """
        min_fields = 1 + max(self.local_address_field,
                             self.remote_address_field,
                             self.connection_state_field)
        wildcard_ip = self.match_all_ips[self.family]
        active_connections = 0
        f = open(self.source_file[self.family])
        try:
            for line in f.readlines():
                fields = line.strip().split()
                # Blank or truncated lines carry no connection information.
                if len(fields) < min_fields:
                    continue
                # Skip the column header line.
                if fields[self.local_address_field] == 'local_address':
                    continue
                if fields[self.connection_state_field] not in self.connection_states:
                    continue
                (local_ip, local_port) = fields[self.local_address_field].split(':')
                if self.port != local_port:
                    continue
                if self.ip not in (wildcard_ip, local_ip):
                    continue
                (remote_ip, remote_port) = fields[self.remote_address_field].split(':')
                if remote_ip not in self.exclude_ips:
                    active_connections += 1
        finally:
            # Close the handle even when a malformed line raises.
            f.close()
        return active_connections
def _convert_host_to_ip(host):
    """
    Resolve ``host`` and report the first address the resolver returns.
    An IP address given as ``host`` comes back as the same IP.

    Args:
        host: String with either hostname, IPv4, or IPv6 address

    Returns:
        Tuple containing address family and IP
    """
    # The port (80) is only a placeholder required by getaddrinfo(); just the
    # family and the address part of the first result are used.
    candidates = socket.getaddrinfo(host, 80, 0, 0, socket.SOL_TCP)
    first = candidates[0]
    family = first[0]
    sockaddr = first[4]
    return (family, sockaddr[0])
def _convert_host_to_hex(host):
    """
    Convert the provided host to the format in /proc/net/tcp*

    /proc/net/tcp uses little-endian four byte hex for ipv4
    /proc/net/tcp6 uses little-endian per 4B word for ipv6

    Args:
        host: String with either hostname, IPv4, or IPv6 address

    Returns:
        Tuple containing address family and the little-endian converted host
    """
    (family, ip) = _convert_host_to_ip(host)
    hexed = binascii.hexlify(socket.inet_pton(family, ip)).upper()
    if not isinstance(hexed, str):
        # hexlify() returns bytes on Python 3; the helpers below slice and
        # join text, so normalise to str (no-op on Python 2).
        hexed = hexed.decode('ascii')
    if family == socket.AF_INET:
        hexed = _little_endian_convert_32bit(hexed)
    elif family == socket.AF_INET6:
        # Walk each 8 character (4 byte) word of the 128 bit address.
        words = [hexed[x:x + 8] for x in range(0, 32, 8)]
        hexed = "".join([_little_endian_convert_32bit(word) for word in words])
    return (family, hexed)
def _little_endian_convert_32bit(block):
    """
    Convert to little-endian, effectively transposing
    the order of the four byte word
    12345678 -> 78563412

    Args:
        block: String containing a 4 byte hex representation

    Returns:
        String containing the little-endian converted block
    """
    # Offsets 6, 4, 2, 0: take the two-character bytes from the end of the
    # string back to the beginning.  range() is used rather than xrange() so
    # the helper also runs on Python 3.
    return "".join([block[x:x + 2] for x in range(6, -2, -2)])
2015-01-05 21:05:01 +03:00
def _create_connection(address, connect_timeout):
    """
    Connect to a 2-tuple (host, port) and return
    the socket object.

    Args:
        address: 2-tuple (host, port) to connect to
        connect_timeout: timeout applied to the connection attempt

    Returns:
        Socket object

    Raises:
        Whatever the underlying socket calls raise when the connection
        cannot be established.
    """
    # Unpack here rather than in the signature: tuple parameters are a
    # syntax error on Python 3, and callers pass the tuple positionally.
    (host, port) = address
    if sys.version_info < (2, 6):
        # socket.create_connection() is not used on older interpreters.
        (family, _) = _convert_host_to_ip(host)
        connect_socket = socket.socket(family, socket.SOCK_STREAM)
        try:
            connect_socket.settimeout(connect_timeout)
            connect_socket.connect((host, port))
        except Exception:
            # Do not leak the descriptor when the attempt fails.
            connect_socket.close()
            raise
    else:
        connect_socket = socket.create_connection((host, port), connect_timeout)
    return connect_socket
def _close_socket(sock):
    """Shut down and close ``sock``, ignoring errors from an already dead peer."""
    try:
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except Exception:
            pass
    finally:
        sock.close()


def main():
    """
    Entry point of the wait_for module.

    Reads the module parameters and then, depending on ``state``:
      - with neither ``port`` nor ``path`` (and state != drained): sleeps
        for ``timeout`` seconds;
      - stopped/absent: waits until ``path`` no longer exists or ``port``
        refuses connections;
      - started/present: waits until ``path`` exists (and matches
        ``search_regex`` if given) or ``port`` accepts a connection (and
        the data it sends matches ``search_regex`` if given);
      - drained: waits until no active connections remain on ``port``.

    Conditions are polled once per second.  The result is reported through
    module.exit_json(), or module.fail_json() when ``timeout`` seconds have
    passed since the start (``delay`` included).  Like the rest of this
    function, the validation below assumes fail_json() does not return.
    """
    module = AnsibleModule(
        argument_spec = dict(
            host=dict(default='127.0.0.1'),
            timeout=dict(default=300),
            connect_timeout=dict(default=5),
            delay=dict(default=0),
            port=dict(default=None),
            path=dict(default=None),
            search_regex=dict(default=None),
            state=dict(default='started', choices=['started', 'stopped', 'present', 'absent', 'drained']),
            exclude_hosts=dict(default=None, type='list')
        ),
    )

    params = module.params

    host = params['host']
    timeout = int(params['timeout'])
    connect_timeout = int(params['connect_timeout'])
    delay = int(params['delay'])
    if params['port']:
        port = int(params['port'])
    else:
        port = None
    state = params['state']
    path = params['path']
    search_regex = params['search_regex']

    if port and path:
        module.fail_json(msg="port and path parameter can not both be passed to wait_for")
    if path and state == 'stopped':
        module.fail_json(msg="state=stopped should only be used for checking a port in the wait_for module")
    if path and state == 'drained':
        module.fail_json(msg="state=drained should only be used for checking a port in the wait_for module")
    if params['exclude_hosts'] is not None and state != 'drained':
        module.fail_json(msg="exclude_hosts should only be with state=drained")
    if state == 'drained' and not params['port']:
        # Without this check the connection counter fails on int(None).
        module.fail_json(msg="port is required when state=drained in the wait_for module")

    start = datetime.datetime.now()

    if delay:
        time.sleep(delay)

    if not port and not path and state != 'drained':
        time.sleep(timeout)
    elif state in ['stopped', 'absent']:
        ### first wait for the stop condition
        end = start + datetime.timedelta(seconds=timeout)

        while datetime.datetime.now() < end:
            if path:
                # stat() instead of open(): directories and unreadable files
                # exist and must not be reported as absent.
                try:
                    os.stat(path)
                except OSError:
                    break
                time.sleep(1)
            else:
                # Only port or path can reach this branch, so port is set.
                try:
                    s = _create_connection((host, port), connect_timeout)
                except Exception:
                    # Nothing accepts connections any more: stopped.
                    break
                _close_socket(s)
                time.sleep(1)
        else:
            elapsed = datetime.datetime.now() - start
            if port:
                module.fail_json(msg="Timeout when waiting for %s:%s to stop." % (host, port), elapsed=elapsed.seconds)
            elif path:
                module.fail_json(msg="Timeout when waiting for %s to be absent." % (path), elapsed=elapsed.seconds)

    elif state in ['started', 'present']:
        ### wait for start condition
        end = start + datetime.timedelta(seconds=timeout)
        while datetime.datetime.now() < end:
            if path:
                try:
                    os.stat(path)
                except OSError:
                    e = sys.exc_info()[1]
                    # errno 2 (ENOENT): file not present yet, keep waiting.
                    if e.errno != 2:
                        elapsed = datetime.datetime.now() - start
                        module.fail_json(msg="Failed to stat %s, %s" % (path, e.strerror), elapsed=elapsed.seconds)
                    time.sleep(1)
                    continue
                if not search_regex:
                    break
                try:
                    f = open(path)
                    try:
                        content = f.read()
                    finally:
                        f.close()
                except IOError:
                    # Present but not readable (yet); try again.
                    time.sleep(1)
                    continue
                if re.search(search_regex, content, re.MULTILINE):
                    break
                time.sleep(1)
            else:
                # Only port or path can reach this branch, so port is set.
                try:
                    s = _create_connection((host, port), connect_timeout)
                except Exception:
                    time.sleep(1)
                    continue
                if not search_regex:
                    _close_socket(s)
                    break
                matched = False
                data = ''
                try:
                    try:
                        # Bounded by the deadline so an endless non-matching
                        # stream cannot outlive the timeout.
                        while datetime.datetime.now() < end:
                            chunk = s.recv(1024)
                            # An empty chunk means the peer closed the
                            # connection; testing the accumulated buffer
                            # here would loop forever.
                            if not chunk:
                                break
                            data += chunk
                            if re.search(search_regex, data, re.MULTILINE):
                                matched = True
                                break
                    except Exception:
                        # recv timeout or reset: treated as "no match yet".
                        pass
                finally:
                    _close_socket(s)
                if matched:
                    break
                time.sleep(1)
        else:
            elapsed = datetime.datetime.now() - start
            if port:
                if search_regex:
                    module.fail_json(msg="Timeout when waiting for search string %s in %s:%s" % (search_regex, host, port), elapsed=elapsed.seconds)
                else:
                    module.fail_json(msg="Timeout when waiting for %s:%s" % (host, port), elapsed=elapsed.seconds)
            elif path:
                if search_regex:
                    module.fail_json(msg="Timeout when waiting for search string %s in %s" % (search_regex, path), elapsed=elapsed.seconds)
                else:
                    module.fail_json(msg="Timeout when waiting for file %s" % (path), elapsed=elapsed.seconds)

    elif state == 'drained':
        ### wait until all active connections are gone
        end = start + datetime.timedelta(seconds=timeout)
        tcpconns = TCPConnectionInfo(module)
        while datetime.datetime.now() < end:
            try:
                if tcpconns.get_active_connections_count() == 0:
                    break
            except IOError:
                # The connection table could not be read this time; retry.
                pass
            time.sleep(1)
        else:
            elapsed = datetime.datetime.now() - start
            module.fail_json(msg="Timeout when waiting for %s:%s to drain" % (host, port), elapsed=elapsed.seconds)

    elapsed = datetime.datetime.now() - start
    module.exit_json(state=state, port=port, search_regex=search_regex, path=path, elapsed=elapsed.seconds)
2013-12-02 15:13:49 -05:00
# import module snippets
# NOTE(review): AnsibleModule, load_platform_subclass and os are used above
# but are not imported at the top of this file, so they presumably come from
# this wildcard import -- confirm before changing or moving this line.
from ansible.module_utils.basic import *
# Run the module; main() reports its result through exit_json()/fail_json().
main()