Postgres module_utils: add get_connect_params + unit tests (#58067)

* add get_conn_params

* add get_conn_params: add to the modules
This commit is contained in:
Andrey Klychkov 2019-06-19 16:49:19 +03:00 committed by Toshio Kuratomi
parent 6bace8aa54
commit 64d0559e9f
18 changed files with 207 additions and 106 deletions

View file

@ -66,52 +66,24 @@ def ensure_required_libs(module):
module.fail_json(msg='psycopg2 must be at least 2.4.3 in order to use the ca_cert parameter')
def connect_to_db(module, autocommit=False, fail_on_conn=True, warn_db_default=True):
"""Return psycopg2 connection object.
def connect_to_db(module, conn_params, autocommit=False, fail_on_conn=True):
"""Connect to a PostgreSQL database.
Keyword arguments:
module -- object of ansible.module_utils.basic.AnsibleModule class
autocommit -- commit automatically (default False)
fail_on_conn -- fail if connection failed or just warn and return None (default True)
warn_db_default -- warn that the default DB is used (default True)
Return psycopg2 connection object.
Args:
module (AnsibleModule) -- object of ansible.module_utils.basic.AnsibleModule class
conn_params (dict) -- dictionary with connection parameters
Kwargs:
autocommit (bool) -- commit automatically (default False)
fail_on_conn (bool) -- fail if connection failed or just warn and return None (default True)
"""
ensure_required_libs(module)
# To use defaults values, keyword arguments must be absent, so
# check which values are empty and don't include in the **kw
# dictionary
params_map = {
"login_host": "host",
"login_user": "user",
"login_password": "password",
"port": "port",
"ssl_mode": "sslmode",
"ca_cert": "sslrootcert"
}
# Might be different in the modules:
if module.params.get('db'):
params_map['db'] = 'database'
elif module.params.get('database'):
params_map['database'] = 'database'
elif module.params.get('login_db'):
params_map['login_db'] = 'database'
else:
if warn_db_default:
module.warn('Database name has not been passed, '
'used default database to connect to.')
kw = dict((params_map[k], v) for (k, v) in iteritems(module.params)
if k in params_map and v != '' and v is not None)
# If a login_unix_socket is specified, incorporate it here.
is_localhost = "host" not in kw or kw["host"] is None or kw["host"] == "localhost"
if is_localhost and module.params["login_unix_socket"] != "":
kw["host"] = module.params["login_unix_socket"]
db_connection = None
try:
db_connection = psycopg2.connect(**kw)
db_connection = psycopg2.connect(**conn_params)
if autocommit:
if LooseVersion(psycopg2.__version__) >= LooseVersion('2.4.2'):
db_connection.set_session(autocommit=True)
@ -179,3 +151,49 @@ def exec_sql(obj, query, ddl=False, add_to_executed=True):
except Exception as e:
obj.module.fail_json(msg="Cannot execute SQL '%s': %s" % (query, to_native(e)))
return False
def get_conn_params(module, params_dict, warn_db_default=True):
"""Get connection parameters from the passed dictionary.
Return a dictionary with parameters to connect to PostgreSQL server.
Args:
module (AnsibleModule) -- object of ansible.module_utils.basic.AnsibleModule class
params_dict (dict) -- dictionary with variables
Kwargs:
warn_db_default (bool) -- warn that the default DB is used (default True)
"""
# To use defaults values, keyword arguments must be absent, so
# check which values are empty and don't include in the return dictionary
params_map = {
"login_host": "host",
"login_user": "user",
"login_password": "password",
"port": "port",
"ssl_mode": "sslmode",
"ca_cert": "sslrootcert"
}
# Might be different in the modules:
if params_dict.get('db'):
params_map['db'] = 'database'
elif params_dict.get('database'):
params_map['database'] = 'database'
elif params_dict.get('login_db'):
params_map['login_db'] = 'database'
else:
if warn_db_default:
module.warn('Database name has not been passed, '
'used default database to connect to.')
kw = dict((params_map[k], v) for (k, v) in iteritems(params_dict)
if k in params_map and v != '' and v is not None)
# If a login_unix_socket is specified, incorporate it here.
is_localhost = "host" not in kw or kw["host"] is None or kw["host"] == "localhost"
if is_localhost and params_dict["login_unix_socket"] != "":
kw["host"] = params_dict["login_unix_socket"]
return kw

View file

@ -178,6 +178,7 @@ from ansible.module_utils.database import pg_quote_identifier
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)
from ansible.module_utils.six import iteritems
@ -351,7 +352,8 @@ def main():
module.fail_json(msg='src param is necessary with copy_to')
# Connect to DB and make cursor object:
db_connection = connect_to_db(module, autocommit=False)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=False)
cursor = db_connection.cursor(cursor_factory=DictCursor)
##############

View file

@ -143,7 +143,11 @@ except ImportError:
pass
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
from ansible.module_utils.postgres import (
connect_to_db,
get_conn_params,
postgres_common_argument_spec,
)
from ansible.module_utils._text import to_native
from ansible.module_utils.database import pg_quote_identifier
@ -216,7 +220,8 @@ def main():
cascade = module.params["cascade"]
changed = False
db_connection = connect_to_db(module, autocommit=True)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=True)
cursor = db_connection.cursor(cursor_factory=DictCursor)
try:

View file

@ -230,6 +230,7 @@ from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)
@ -474,7 +475,8 @@ def main():
if cascade and state != 'absent':
module.fail_json(msg="cascade parameter used only with state=absent")
db_connection = connect_to_db(module, autocommit=True)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=True)
cursor = db_connection.cursor(cursor_factory=DictCursor)
# Set defaults:

View file

@ -475,7 +475,11 @@ except ImportError:
pass
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
from ansible.module_utils.postgres import (
connect_to_db,
get_conn_params,
postgres_common_argument_spec,
)
from ansible.module_utils._text import to_native
@ -502,7 +506,8 @@ class PgDbConn(object):
Note: connection parameters are passed by self.module object.
"""
self.db_conn = connect_to_db(self.module, warn_db_default=False)
conn_params = get_conn_params(self.module, self.module.params, warn_db_default=False)
self.db_conn = connect_to_db(self.module, conn_params)
return self.db_conn.cursor(cursor_factory=DictCursor)
def reconnect(self, dbname):

View file

@ -170,7 +170,11 @@ queries:
'''
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
from ansible.module_utils.postgres import (
connect_to_db,
get_conn_params,
postgres_common_argument_spec,
)
from ansible.module_utils._text import to_native
from ansible.module_utils.database import pg_quote_identifier
@ -254,7 +258,8 @@ def main():
cascade = module.params["cascade"]
fail_on_drop = module.params["fail_on_drop"]
db_connection = connect_to_db(module, autocommit=False)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=False)
cursor = db_connection.cursor()
changed = False

View file

@ -147,6 +147,7 @@ from ansible.module_utils.database import pg_quote_identifier
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)
@ -284,7 +285,8 @@ def main():
fail_on_role = module.params['fail_on_role']
state = module.params['state']
db_connection = connect_to_db(module, autocommit=False)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=False)
cursor = db_connection.cursor(cursor_factory=DictCursor)
##############

View file

@ -161,6 +161,7 @@ from ansible.module_utils.database import pg_quote_identifier
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)
@ -415,7 +416,8 @@ def main():
reassign_owned_by = module.params['reassign_owned_by']
fail_on_role = module.params['fail_on_role']
db_connection = connect_to_db(module, autocommit=False)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=False)
cursor = db_connection.cursor(cursor_factory=DictCursor)
##############

View file

@ -82,6 +82,7 @@ from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)
@ -138,7 +139,8 @@ def main():
server_version=dict(),
)
db_connection = connect_to_db(module, fail_on_conn=False)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, fail_on_conn=False)
if db_connection is not None:
cursor = db_connection.cursor(cursor_factory=DictCursor)

View file

@ -146,7 +146,11 @@ except ImportError:
pass
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
from ansible.module_utils.postgres import (
connect_to_db,
get_conn_params,
postgres_common_argument_spec,
)
from ansible.module_utils._text import to_native
@ -189,7 +193,8 @@ def main():
except Exception as e:
module.fail_json(msg="Cannot read file '%s' : %s" % (path_to_script, to_native(e)))
db_connection = connect_to_db(module, autocommit=False)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=False)
cursor = db_connection.cursor(cursor_factory=DictCursor)
# Prepare args:

View file

@ -129,7 +129,11 @@ except ImportError:
pass
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
from ansible.module_utils.postgres import (
connect_to_db,
get_conn_params,
postgres_common_argument_spec,
)
from ansible.module_utils.database import SQLParseError, pg_quote_identifier
from ansible.module_utils._text import to_native
@ -234,7 +238,8 @@ def main():
cascade_drop = module.params["cascade_drop"]
changed = False
db_connection = connect_to_db(module, autocommit=True)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=True)
cursor = db_connection.cursor(cursor_factory=DictCursor)
try:

View file

@ -287,6 +287,7 @@ from ansible.module_utils.database import pg_quote_identifier
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)
@ -498,7 +499,8 @@ def main():
# Change autocommit to False if check_mode:
autocommit = not module.check_mode
# Connect to DB and make cursor object:
db_connection = connect_to_db(module, autocommit=autocommit)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=autocommit)
cursor = db_connection.cursor(cursor_factory=DictCursor)
##############

View file

@ -165,7 +165,11 @@ except Exception:
from copy import deepcopy
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
from ansible.module_utils.postgres import (
connect_to_db,
get_conn_params,
postgres_common_argument_spec,
)
from ansible.module_utils._text import to_native
PG_REQ_VER = 90400
@ -304,7 +308,8 @@ def main():
if not value and not reset:
module.fail_json(msg="%s: at least one of value or reset param must be specified" % name)
db_connection = connect_to_db(module, autocommit=True, warn_db_default=False)
conn_params = get_conn_params(module, module.params, warn_db_default=False)
db_connection = connect_to_db(module, conn_params, autocommit=True)
cursor = db_connection.cursor(cursor_factory=DictCursor)
kw = {}
@ -397,7 +402,7 @@ def main():
# Reconnect and recheck current value:
if context in ('sighup', 'superuser-backend', 'backend', 'superuser', 'user'):
db_connection = connect_to_db(module, autocommit=True)
db_connection = connect_to_db(module, conn_params, autocommit=True)
cursor = db_connection.cursor(cursor_factory=DictCursor)
res = param_get(cursor, module, name)

View file

@ -152,6 +152,7 @@ from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)
@ -242,7 +243,8 @@ def main():
if immediately_reserve and slot_type == 'logical':
module.fail_json(msg="Module parameters immediately_reserve and slot_type=logical are mutually exclusive")
db_connection = connect_to_db(module, autocommit=True)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=True)
cursor = db_connection.cursor(cursor_factory=DictCursor)
##################################

View file

@ -240,6 +240,7 @@ from ansible.module_utils.database import pg_quote_identifier
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)
@ -514,7 +515,8 @@ def main():
if including and not like:
module.fail_json(msg="%s: including param needs like param specified" % table)
db_connection = connect_to_db(module, autocommit=False)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=False)
cursor = db_connection.cursor(cursor_factory=DictCursor)
if storage_params:

View file

@ -176,6 +176,7 @@ from ansible.module_utils.database import pg_quote_identifier
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)
@ -394,7 +395,8 @@ def main():
module.fail_json(msg="state=absent is mutually exclusive location, "
"owner, rename_to, and set")
db_connection = connect_to_db(module, autocommit=True)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=True)
cursor = db_connection.cursor(cursor_factory=DictCursor)
# Change autocommit to False if check_mode:

View file

@ -244,7 +244,11 @@ except ImportError:
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.database import pg_quote_identifier, SQLParseError
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
from ansible.module_utils.postgres import (
connect_to_db,
get_conn_params,
postgres_common_argument_spec,
)
from ansible.module_utils._text import to_bytes, to_native
from ansible.module_utils.six import iteritems
@ -801,7 +805,8 @@ def main():
conn_limit = module.params["conn_limit"]
role_attr_flags = module.params["role_attr_flags"]
db_connection = connect_to_db(module, warn_db_default=False)
conn_params = get_conn_params(module, module.params, warn_db_default=False)
db_connection = connect_to_db(module, conn_params)
cursor = db_connection.cursor(cursor_factory=DictCursor)
try:

View file

@ -6,6 +6,33 @@ import pytest
import ansible.module_utils.postgres as pg
INPUT_DICT = dict(
session_role=dict(default=''),
login_user=dict(default='postgres'),
login_password=dict(default='test', no_log=True),
login_host=dict(default='test'),
login_unix_socket=dict(default=''),
port=dict(type='int', default=5432, aliases=['login_port']),
ssl_mode=dict(
default='prefer',
choices=['allow', 'disable', 'prefer', 'require', 'verify-ca', 'verify-full']
),
ca_cert=dict(aliases=['ssl_rootcert']),
)
EXPECTED_DICT = dict(
user=dict(default='postgres'),
password=dict(default='test', no_log=True),
host=dict(default='test'),
port=dict(type='int', default=5432, aliases=['login_port']),
sslmode=dict(
default='prefer',
choices=['allow', 'disable', 'prefer', 'require', 'verify-ca', 'verify-full']
),
sslrootcert=dict(aliases=['ssl_rootcert']),
)
class TestPostgresCommonArgSpec():
"""
@ -154,6 +181,24 @@ class TestEnsureReqLibs():
assert 'psycopg2 must be at least 2.4.3' in m_ansible_module.err_msg
@pytest.fixture(scope='class')
def m_ansible_module():
"""Return an object of dummy AnsibleModule class."""
class DummyAnsibleModule():
def __init__(self):
self.params = pg.postgres_common_argument_spec()
self.err_msg = ''
self.warn_msg = ''
def fail_json(self, msg):
self.err_msg = msg
def warn(self, msg):
self.warn_msg = msg
return DummyAnsibleModule()
class TestConnectToDb():
"""
@ -168,29 +213,13 @@ class TestConnectToDb():
2. Types of return objects (db_connection and cursor).
"""
@pytest.fixture(scope='class')
def m_ansible_module(self):
"""Return an object of dummy AnsibleModule class."""
class DummyAnsibleModule():
def __init__(self):
self.params = pg.postgres_common_argument_spec()
self.err_msg = ''
self.warn_msg = ''
def fail_json(self, msg):
self.err_msg = msg
def warn(self, msg):
self.warn_msg = msg
return DummyAnsibleModule()
def test_connect_to_db(self, m_ansible_module, monkeypatch, m_psycopg2):
"""Test connect_to_db(), common test."""
monkeypatch.setattr(pg, 'HAS_PSYCOPG2', True)
monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)
db_connection = pg.connect_to_db(m_ansible_module)
conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
db_connection = pg.connect_to_db(m_ansible_module, conn_params)
cursor = db_connection.cursor()
# if errors, db_connection returned as None:
assert isinstance(db_connection, DbConnection)
@ -205,7 +234,8 @@ class TestConnectToDb():
monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)
m_ansible_module.params['session_role'] = 'test_role'
db_connection = pg.connect_to_db(m_ansible_module)
conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
db_connection = pg.connect_to_db(m_ansible_module, conn_params)
cursor = db_connection.cursor()
# if errors, db_connection returned as None:
assert isinstance(db_connection, DbConnection)
@ -214,25 +244,6 @@ class TestConnectToDb():
# The default behaviour, normal in this case:
assert 'Database name has not been passed' in m_ansible_module.warn_msg
def test_warn_db_default_non_default(self, m_ansible_module, monkeypatch, m_psycopg2):
"""
Test connect_to_db(), warn_db_default arg passed as False (by default is True).
"""
monkeypatch.setattr(pg, 'HAS_PSYCOPG2', True)
monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)
db_connection = pg.connect_to_db(m_ansible_module, warn_db_default=False)
cursor = db_connection.cursor()
# if errors, db_connection returned as None:
assert isinstance(db_connection, DbConnection)
assert isinstance(cursor, Cursor)
assert m_ansible_module.err_msg == ''
assert m_ansible_module.warn_msg == ''
# pay attention that warn_db_defaul=True has been checked
# in the previous tests by
# assert('Database name has not been passed' in m_ansible_module.warn_msg)
# because of this is the default behavior
def test_fail_on_conn_true(self, m_ansible_module, monkeypatch, m_psycopg2):
"""
Test connect_to_db(), fail_on_conn arg passed as True (the default behavior).
@ -242,7 +253,8 @@ class TestConnectToDb():
m_ansible_module.params['login_user'] = 'Exception' # causes Exception
db_connection = pg.connect_to_db(m_ansible_module, fail_on_conn=True)
conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
db_connection = pg.connect_to_db(m_ansible_module, conn_params, fail_on_conn=True)
assert 'unable to connect to database' in m_ansible_module.err_msg
assert db_connection is None
@ -256,7 +268,8 @@ class TestConnectToDb():
m_ansible_module.params['login_user'] = 'Exception' # causes Exception
db_connection = pg.connect_to_db(m_ansible_module, fail_on_conn=False)
conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
db_connection = pg.connect_to_db(m_ansible_module, conn_params, fail_on_conn=False)
assert m_ansible_module.err_msg == ''
assert 'PostgreSQL server is unavailable' in m_ansible_module.warn_msg
@ -271,7 +284,8 @@ class TestConnectToDb():
# case 1: psycopg2.__version >= 2.4.2 (the default in m_psycopg2)
monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)
db_connection = pg.connect_to_db(m_ansible_module, autocommit=True)
conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
db_connection = pg.connect_to_db(m_ansible_module, conn_params, autocommit=True)
cursor = db_connection.cursor()
# if errors, db_connection returned as None:
@ -283,10 +297,26 @@ class TestConnectToDb():
m_psycopg2.__version__ = '2.4.1'
monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)
db_connection = pg.connect_to_db(m_ansible_module, autocommit=True)
conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
db_connection = pg.connect_to_db(m_ansible_module, conn_params, autocommit=True)
cursor = db_connection.cursor()
# if errors, db_connection returned as None:
assert isinstance(db_connection, DbConnection)
assert isinstance(cursor, Cursor)
assert 'psycopg2 must be at least 2.4.3' in m_ansible_module.err_msg
class TestGetConnParams():
"""Namespace for testing get_conn_params() function."""
def test_get_conn_params_def(self, m_ansible_module):
"""Test get_conn_params(), warn_db_default kwarg is default."""
assert pg.get_conn_params(m_ansible_module, INPUT_DICT) == EXPECTED_DICT
assert m_ansible_module.warn_msg == 'Database name has not been passed, used default database to connect to.'
def test_get_conn_params_warn_db_def_false(self, m_ansible_module):
"""Test get_conn_params(), warn_db_default kwarg is False."""
assert pg.get_conn_params(m_ansible_module, INPUT_DICT, warn_db_default=False) == EXPECTED_DICT
assert m_ansible_module.warn_msg == ''