Postgres module_utils: add get_connect_params + unit tests (#58067)
* add get_conn_params * add get_conn_params: add to the modules
This commit is contained in:
parent
6bace8aa54
commit
64d0559e9f
18 changed files with 207 additions and 106 deletions
|
@ -66,52 +66,24 @@ def ensure_required_libs(module):
|
|||
module.fail_json(msg='psycopg2 must be at least 2.4.3 in order to use the ca_cert parameter')
|
||||
|
||||
|
||||
def connect_to_db(module, autocommit=False, fail_on_conn=True, warn_db_default=True):
|
||||
"""Return psycopg2 connection object.
|
||||
def connect_to_db(module, conn_params, autocommit=False, fail_on_conn=True):
|
||||
"""Connect to a PostgreSQL database.
|
||||
|
||||
Keyword arguments:
|
||||
module -- object of ansible.module_utils.basic.AnsibleModule class
|
||||
autocommit -- commit automatically (default False)
|
||||
fail_on_conn -- fail if connection failed or just warn and return None (default True)
|
||||
warn_db_default -- warn that the default DB is used (default True)
|
||||
Return psycopg2 connection object.
|
||||
|
||||
Args:
|
||||
module (AnsibleModule) -- object of ansible.module_utils.basic.AnsibleModule class
|
||||
conn_params (dict) -- dictionary with connection parameters
|
||||
|
||||
Kwargs:
|
||||
autocommit (bool) -- commit automatically (default False)
|
||||
fail_on_conn (bool) -- fail if connection failed or just warn and return None (default True)
|
||||
"""
|
||||
ensure_required_libs(module)
|
||||
|
||||
# To use defaults values, keyword arguments must be absent, so
|
||||
# check which values are empty and don't include in the **kw
|
||||
# dictionary
|
||||
params_map = {
|
||||
"login_host": "host",
|
||||
"login_user": "user",
|
||||
"login_password": "password",
|
||||
"port": "port",
|
||||
"ssl_mode": "sslmode",
|
||||
"ca_cert": "sslrootcert"
|
||||
}
|
||||
|
||||
# Might be different in the modules:
|
||||
if module.params.get('db'):
|
||||
params_map['db'] = 'database'
|
||||
elif module.params.get('database'):
|
||||
params_map['database'] = 'database'
|
||||
elif module.params.get('login_db'):
|
||||
params_map['login_db'] = 'database'
|
||||
else:
|
||||
if warn_db_default:
|
||||
module.warn('Database name has not been passed, '
|
||||
'used default database to connect to.')
|
||||
|
||||
kw = dict((params_map[k], v) for (k, v) in iteritems(module.params)
|
||||
if k in params_map and v != '' and v is not None)
|
||||
|
||||
# If a login_unix_socket is specified, incorporate it here.
|
||||
is_localhost = "host" not in kw or kw["host"] is None or kw["host"] == "localhost"
|
||||
if is_localhost and module.params["login_unix_socket"] != "":
|
||||
kw["host"] = module.params["login_unix_socket"]
|
||||
|
||||
db_connection = None
|
||||
try:
|
||||
db_connection = psycopg2.connect(**kw)
|
||||
db_connection = psycopg2.connect(**conn_params)
|
||||
if autocommit:
|
||||
if LooseVersion(psycopg2.__version__) >= LooseVersion('2.4.2'):
|
||||
db_connection.set_session(autocommit=True)
|
||||
|
@ -179,3 +151,49 @@ def exec_sql(obj, query, ddl=False, add_to_executed=True):
|
|||
except Exception as e:
|
||||
obj.module.fail_json(msg="Cannot execute SQL '%s': %s" % (query, to_native(e)))
|
||||
return False
|
||||
|
||||
|
||||
def get_conn_params(module, params_dict, warn_db_default=True):
|
||||
"""Get connection parameters from the passed dictionary.
|
||||
|
||||
Return a dictionary with parameters to connect to PostgreSQL server.
|
||||
|
||||
Args:
|
||||
module (AnsibleModule) -- object of ansible.module_utils.basic.AnsibleModule class
|
||||
params_dict (dict) -- dictionary with variables
|
||||
|
||||
Kwargs:
|
||||
warn_db_default (bool) -- warn that the default DB is used (default True)
|
||||
"""
|
||||
# To use defaults values, keyword arguments must be absent, so
|
||||
# check which values are empty and don't include in the return dictionary
|
||||
params_map = {
|
||||
"login_host": "host",
|
||||
"login_user": "user",
|
||||
"login_password": "password",
|
||||
"port": "port",
|
||||
"ssl_mode": "sslmode",
|
||||
"ca_cert": "sslrootcert"
|
||||
}
|
||||
|
||||
# Might be different in the modules:
|
||||
if params_dict.get('db'):
|
||||
params_map['db'] = 'database'
|
||||
elif params_dict.get('database'):
|
||||
params_map['database'] = 'database'
|
||||
elif params_dict.get('login_db'):
|
||||
params_map['login_db'] = 'database'
|
||||
else:
|
||||
if warn_db_default:
|
||||
module.warn('Database name has not been passed, '
|
||||
'used default database to connect to.')
|
||||
|
||||
kw = dict((params_map[k], v) for (k, v) in iteritems(params_dict)
|
||||
if k in params_map and v != '' and v is not None)
|
||||
|
||||
# If a login_unix_socket is specified, incorporate it here.
|
||||
is_localhost = "host" not in kw or kw["host"] is None or kw["host"] == "localhost"
|
||||
if is_localhost and params_dict["login_unix_socket"] != "":
|
||||
kw["host"] = params_dict["login_unix_socket"]
|
||||
|
||||
return kw
|
||||
|
|
|
@ -178,6 +178,7 @@ from ansible.module_utils.database import pg_quote_identifier
|
|||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
exec_sql,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
from ansible.module_utils.six import iteritems
|
||||
|
@ -351,7 +352,8 @@ def main():
|
|||
module.fail_json(msg='src param is necessary with copy_to')
|
||||
|
||||
# Connect to DB and make cursor object:
|
||||
db_connection = connect_to_db(module, autocommit=False)
|
||||
conn_params = get_conn_params(module, module.params)
|
||||
db_connection = connect_to_db(module, conn_params, autocommit=False)
|
||||
cursor = db_connection.cursor(cursor_factory=DictCursor)
|
||||
|
||||
##############
|
||||
|
|
|
@ -143,7 +143,11 @@ except ImportError:
|
|||
pass
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
|
||||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
from ansible.module_utils._text import to_native
|
||||
from ansible.module_utils.database import pg_quote_identifier
|
||||
|
||||
|
@ -216,7 +220,8 @@ def main():
|
|||
cascade = module.params["cascade"]
|
||||
changed = False
|
||||
|
||||
db_connection = connect_to_db(module, autocommit=True)
|
||||
conn_params = get_conn_params(module, module.params)
|
||||
db_connection = connect_to_db(module, conn_params, autocommit=True)
|
||||
cursor = db_connection.cursor(cursor_factory=DictCursor)
|
||||
|
||||
try:
|
||||
|
|
|
@ -230,6 +230,7 @@ from ansible.module_utils.basic import AnsibleModule
|
|||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
exec_sql,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
|
||||
|
@ -474,7 +475,8 @@ def main():
|
|||
if cascade and state != 'absent':
|
||||
module.fail_json(msg="cascade parameter used only with state=absent")
|
||||
|
||||
db_connection = connect_to_db(module, autocommit=True)
|
||||
conn_params = get_conn_params(module, module.params)
|
||||
db_connection = connect_to_db(module, conn_params, autocommit=True)
|
||||
cursor = db_connection.cursor(cursor_factory=DictCursor)
|
||||
|
||||
# Set defaults:
|
||||
|
|
|
@ -475,7 +475,11 @@ except ImportError:
|
|||
pass
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
|
||||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
from ansible.module_utils._text import to_native
|
||||
|
||||
|
||||
|
@ -502,7 +506,8 @@ class PgDbConn(object):
|
|||
|
||||
Note: connection parameters are passed by self.module object.
|
||||
"""
|
||||
self.db_conn = connect_to_db(self.module, warn_db_default=False)
|
||||
conn_params = get_conn_params(self.module, self.module.params, warn_db_default=False)
|
||||
self.db_conn = connect_to_db(self.module, conn_params)
|
||||
return self.db_conn.cursor(cursor_factory=DictCursor)
|
||||
|
||||
def reconnect(self, dbname):
|
||||
|
|
|
@ -170,7 +170,11 @@ queries:
|
|||
'''
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
|
||||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
from ansible.module_utils._text import to_native
|
||||
from ansible.module_utils.database import pg_quote_identifier
|
||||
|
||||
|
@ -254,7 +258,8 @@ def main():
|
|||
cascade = module.params["cascade"]
|
||||
fail_on_drop = module.params["fail_on_drop"]
|
||||
|
||||
db_connection = connect_to_db(module, autocommit=False)
|
||||
conn_params = get_conn_params(module, module.params)
|
||||
db_connection = connect_to_db(module, conn_params, autocommit=False)
|
||||
cursor = db_connection.cursor()
|
||||
|
||||
changed = False
|
||||
|
|
|
@ -147,6 +147,7 @@ from ansible.module_utils.database import pg_quote_identifier
|
|||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
exec_sql,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
|
||||
|
@ -284,7 +285,8 @@ def main():
|
|||
fail_on_role = module.params['fail_on_role']
|
||||
state = module.params['state']
|
||||
|
||||
db_connection = connect_to_db(module, autocommit=False)
|
||||
conn_params = get_conn_params(module, module.params)
|
||||
db_connection = connect_to_db(module, conn_params, autocommit=False)
|
||||
cursor = db_connection.cursor(cursor_factory=DictCursor)
|
||||
|
||||
##############
|
||||
|
|
|
@ -161,6 +161,7 @@ from ansible.module_utils.database import pg_quote_identifier
|
|||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
exec_sql,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
|
||||
|
@ -415,7 +416,8 @@ def main():
|
|||
reassign_owned_by = module.params['reassign_owned_by']
|
||||
fail_on_role = module.params['fail_on_role']
|
||||
|
||||
db_connection = connect_to_db(module, autocommit=False)
|
||||
conn_params = get_conn_params(module, module.params)
|
||||
db_connection = connect_to_db(module, conn_params, autocommit=False)
|
||||
cursor = db_connection.cursor(cursor_factory=DictCursor)
|
||||
|
||||
##############
|
||||
|
|
|
@ -82,6 +82,7 @@ from ansible.module_utils.basic import AnsibleModule
|
|||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
exec_sql,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
|
||||
|
@ -138,7 +139,8 @@ def main():
|
|||
server_version=dict(),
|
||||
)
|
||||
|
||||
db_connection = connect_to_db(module, fail_on_conn=False)
|
||||
conn_params = get_conn_params(module, module.params)
|
||||
db_connection = connect_to_db(module, conn_params, fail_on_conn=False)
|
||||
|
||||
if db_connection is not None:
|
||||
cursor = db_connection.cursor(cursor_factory=DictCursor)
|
||||
|
|
|
@ -146,7 +146,11 @@ except ImportError:
|
|||
pass
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
|
||||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
from ansible.module_utils._text import to_native
|
||||
|
||||
|
||||
|
@ -189,7 +193,8 @@ def main():
|
|||
except Exception as e:
|
||||
module.fail_json(msg="Cannot read file '%s' : %s" % (path_to_script, to_native(e)))
|
||||
|
||||
db_connection = connect_to_db(module, autocommit=False)
|
||||
conn_params = get_conn_params(module, module.params)
|
||||
db_connection = connect_to_db(module, conn_params, autocommit=False)
|
||||
cursor = db_connection.cursor(cursor_factory=DictCursor)
|
||||
|
||||
# Prepare args:
|
||||
|
|
|
@ -129,7 +129,11 @@ except ImportError:
|
|||
pass
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
|
||||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
from ansible.module_utils.database import SQLParseError, pg_quote_identifier
|
||||
from ansible.module_utils._text import to_native
|
||||
|
||||
|
@ -234,7 +238,8 @@ def main():
|
|||
cascade_drop = module.params["cascade_drop"]
|
||||
changed = False
|
||||
|
||||
db_connection = connect_to_db(module, autocommit=True)
|
||||
conn_params = get_conn_params(module, module.params)
|
||||
db_connection = connect_to_db(module, conn_params, autocommit=True)
|
||||
cursor = db_connection.cursor(cursor_factory=DictCursor)
|
||||
|
||||
try:
|
||||
|
|
|
@ -287,6 +287,7 @@ from ansible.module_utils.database import pg_quote_identifier
|
|||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
exec_sql,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
|
||||
|
@ -498,7 +499,8 @@ def main():
|
|||
# Change autocommit to False if check_mode:
|
||||
autocommit = not module.check_mode
|
||||
# Connect to DB and make cursor object:
|
||||
db_connection = connect_to_db(module, autocommit=autocommit)
|
||||
conn_params = get_conn_params(module, module.params)
|
||||
db_connection = connect_to_db(module, conn_params, autocommit=autocommit)
|
||||
cursor = db_connection.cursor(cursor_factory=DictCursor)
|
||||
|
||||
##############
|
||||
|
|
|
@ -165,7 +165,11 @@ except Exception:
|
|||
from copy import deepcopy
|
||||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
|
||||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
from ansible.module_utils._text import to_native
|
||||
|
||||
PG_REQ_VER = 90400
|
||||
|
@ -304,7 +308,8 @@ def main():
|
|||
if not value and not reset:
|
||||
module.fail_json(msg="%s: at least one of value or reset param must be specified" % name)
|
||||
|
||||
db_connection = connect_to_db(module, autocommit=True, warn_db_default=False)
|
||||
conn_params = get_conn_params(module, module.params, warn_db_default=False)
|
||||
db_connection = connect_to_db(module, conn_params, autocommit=True)
|
||||
cursor = db_connection.cursor(cursor_factory=DictCursor)
|
||||
|
||||
kw = {}
|
||||
|
@ -397,7 +402,7 @@ def main():
|
|||
|
||||
# Reconnect and recheck current value:
|
||||
if context in ('sighup', 'superuser-backend', 'backend', 'superuser', 'user'):
|
||||
db_connection = connect_to_db(module, autocommit=True)
|
||||
db_connection = connect_to_db(module, conn_params, autocommit=True)
|
||||
cursor = db_connection.cursor(cursor_factory=DictCursor)
|
||||
|
||||
res = param_get(cursor, module, name)
|
||||
|
|
|
@ -152,6 +152,7 @@ from ansible.module_utils.basic import AnsibleModule
|
|||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
exec_sql,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
|
||||
|
@ -242,7 +243,8 @@ def main():
|
|||
if immediately_reserve and slot_type == 'logical':
|
||||
module.fail_json(msg="Module parameters immediately_reserve and slot_type=logical are mutually exclusive")
|
||||
|
||||
db_connection = connect_to_db(module, autocommit=True)
|
||||
conn_params = get_conn_params(module, module.params)
|
||||
db_connection = connect_to_db(module, conn_params, autocommit=True)
|
||||
cursor = db_connection.cursor(cursor_factory=DictCursor)
|
||||
|
||||
##################################
|
||||
|
|
|
@ -240,6 +240,7 @@ from ansible.module_utils.database import pg_quote_identifier
|
|||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
exec_sql,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
|
||||
|
@ -514,7 +515,8 @@ def main():
|
|||
if including and not like:
|
||||
module.fail_json(msg="%s: including param needs like param specified" % table)
|
||||
|
||||
db_connection = connect_to_db(module, autocommit=False)
|
||||
conn_params = get_conn_params(module, module.params)
|
||||
db_connection = connect_to_db(module, conn_params, autocommit=False)
|
||||
cursor = db_connection.cursor(cursor_factory=DictCursor)
|
||||
|
||||
if storage_params:
|
||||
|
|
|
@ -176,6 +176,7 @@ from ansible.module_utils.database import pg_quote_identifier
|
|||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
exec_sql,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
|
||||
|
@ -394,7 +395,8 @@ def main():
|
|||
module.fail_json(msg="state=absent is mutually exclusive location, "
|
||||
"owner, rename_to, and set")
|
||||
|
||||
db_connection = connect_to_db(module, autocommit=True)
|
||||
conn_params = get_conn_params(module, module.params)
|
||||
db_connection = connect_to_db(module, conn_params, autocommit=True)
|
||||
cursor = db_connection.cursor(cursor_factory=DictCursor)
|
||||
|
||||
# Change autocommit to False if check_mode:
|
||||
|
|
|
@ -244,7 +244,11 @@ except ImportError:
|
|||
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
from ansible.module_utils.database import pg_quote_identifier, SQLParseError
|
||||
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
|
||||
from ansible.module_utils.postgres import (
|
||||
connect_to_db,
|
||||
get_conn_params,
|
||||
postgres_common_argument_spec,
|
||||
)
|
||||
from ansible.module_utils._text import to_bytes, to_native
|
||||
from ansible.module_utils.six import iteritems
|
||||
|
||||
|
@ -801,7 +805,8 @@ def main():
|
|||
conn_limit = module.params["conn_limit"]
|
||||
role_attr_flags = module.params["role_attr_flags"]
|
||||
|
||||
db_connection = connect_to_db(module, warn_db_default=False)
|
||||
conn_params = get_conn_params(module, module.params, warn_db_default=False)
|
||||
db_connection = connect_to_db(module, conn_params)
|
||||
cursor = db_connection.cursor(cursor_factory=DictCursor)
|
||||
|
||||
try:
|
||||
|
|
|
@ -6,6 +6,33 @@ import pytest
|
|||
import ansible.module_utils.postgres as pg
|
||||
|
||||
|
||||
INPUT_DICT = dict(
|
||||
session_role=dict(default=''),
|
||||
login_user=dict(default='postgres'),
|
||||
login_password=dict(default='test', no_log=True),
|
||||
login_host=dict(default='test'),
|
||||
login_unix_socket=dict(default=''),
|
||||
port=dict(type='int', default=5432, aliases=['login_port']),
|
||||
ssl_mode=dict(
|
||||
default='prefer',
|
||||
choices=['allow', 'disable', 'prefer', 'require', 'verify-ca', 'verify-full']
|
||||
),
|
||||
ca_cert=dict(aliases=['ssl_rootcert']),
|
||||
)
|
||||
|
||||
EXPECTED_DICT = dict(
|
||||
user=dict(default='postgres'),
|
||||
password=dict(default='test', no_log=True),
|
||||
host=dict(default='test'),
|
||||
port=dict(type='int', default=5432, aliases=['login_port']),
|
||||
sslmode=dict(
|
||||
default='prefer',
|
||||
choices=['allow', 'disable', 'prefer', 'require', 'verify-ca', 'verify-full']
|
||||
),
|
||||
sslrootcert=dict(aliases=['ssl_rootcert']),
|
||||
)
|
||||
|
||||
|
||||
class TestPostgresCommonArgSpec():
|
||||
|
||||
"""
|
||||
|
@ -154,6 +181,24 @@ class TestEnsureReqLibs():
|
|||
assert 'psycopg2 must be at least 2.4.3' in m_ansible_module.err_msg
|
||||
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def m_ansible_module():
|
||||
"""Return an object of dummy AnsibleModule class."""
|
||||
class DummyAnsibleModule():
|
||||
def __init__(self):
|
||||
self.params = pg.postgres_common_argument_spec()
|
||||
self.err_msg = ''
|
||||
self.warn_msg = ''
|
||||
|
||||
def fail_json(self, msg):
|
||||
self.err_msg = msg
|
||||
|
||||
def warn(self, msg):
|
||||
self.warn_msg = msg
|
||||
|
||||
return DummyAnsibleModule()
|
||||
|
||||
|
||||
class TestConnectToDb():
|
||||
|
||||
"""
|
||||
|
@ -168,29 +213,13 @@ class TestConnectToDb():
|
|||
2. Types of return objects (db_connection and cursor).
|
||||
"""
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def m_ansible_module(self):
|
||||
"""Return an object of dummy AnsibleModule class."""
|
||||
class DummyAnsibleModule():
|
||||
def __init__(self):
|
||||
self.params = pg.postgres_common_argument_spec()
|
||||
self.err_msg = ''
|
||||
self.warn_msg = ''
|
||||
|
||||
def fail_json(self, msg):
|
||||
self.err_msg = msg
|
||||
|
||||
def warn(self, msg):
|
||||
self.warn_msg = msg
|
||||
|
||||
return DummyAnsibleModule()
|
||||
|
||||
def test_connect_to_db(self, m_ansible_module, monkeypatch, m_psycopg2):
|
||||
"""Test connect_to_db(), common test."""
|
||||
monkeypatch.setattr(pg, 'HAS_PSYCOPG2', True)
|
||||
monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)
|
||||
|
||||
db_connection = pg.connect_to_db(m_ansible_module)
|
||||
conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
|
||||
db_connection = pg.connect_to_db(m_ansible_module, conn_params)
|
||||
cursor = db_connection.cursor()
|
||||
# if errors, db_connection returned as None:
|
||||
assert isinstance(db_connection, DbConnection)
|
||||
|
@ -205,7 +234,8 @@ class TestConnectToDb():
|
|||
monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)
|
||||
|
||||
m_ansible_module.params['session_role'] = 'test_role'
|
||||
db_connection = pg.connect_to_db(m_ansible_module)
|
||||
conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
|
||||
db_connection = pg.connect_to_db(m_ansible_module, conn_params)
|
||||
cursor = db_connection.cursor()
|
||||
# if errors, db_connection returned as None:
|
||||
assert isinstance(db_connection, DbConnection)
|
||||
|
@ -214,25 +244,6 @@ class TestConnectToDb():
|
|||
# The default behaviour, normal in this case:
|
||||
assert 'Database name has not been passed' in m_ansible_module.warn_msg
|
||||
|
||||
def test_warn_db_default_non_default(self, m_ansible_module, monkeypatch, m_psycopg2):
|
||||
"""
|
||||
Test connect_to_db(), warn_db_default arg passed as False (by default is True).
|
||||
"""
|
||||
monkeypatch.setattr(pg, 'HAS_PSYCOPG2', True)
|
||||
monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)
|
||||
|
||||
db_connection = pg.connect_to_db(m_ansible_module, warn_db_default=False)
|
||||
cursor = db_connection.cursor()
|
||||
# if errors, db_connection returned as None:
|
||||
assert isinstance(db_connection, DbConnection)
|
||||
assert isinstance(cursor, Cursor)
|
||||
assert m_ansible_module.err_msg == ''
|
||||
assert m_ansible_module.warn_msg == ''
|
||||
# pay attention that warn_db_defaul=True has been checked
|
||||
# in the previous tests by
|
||||
# assert('Database name has not been passed' in m_ansible_module.warn_msg)
|
||||
# because of this is the default behavior
|
||||
|
||||
def test_fail_on_conn_true(self, m_ansible_module, monkeypatch, m_psycopg2):
|
||||
"""
|
||||
Test connect_to_db(), fail_on_conn arg passed as True (the default behavior).
|
||||
|
@ -242,7 +253,8 @@ class TestConnectToDb():
|
|||
|
||||
m_ansible_module.params['login_user'] = 'Exception' # causes Exception
|
||||
|
||||
db_connection = pg.connect_to_db(m_ansible_module, fail_on_conn=True)
|
||||
conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
|
||||
db_connection = pg.connect_to_db(m_ansible_module, conn_params, fail_on_conn=True)
|
||||
|
||||
assert 'unable to connect to database' in m_ansible_module.err_msg
|
||||
assert db_connection is None
|
||||
|
@ -256,7 +268,8 @@ class TestConnectToDb():
|
|||
|
||||
m_ansible_module.params['login_user'] = 'Exception' # causes Exception
|
||||
|
||||
db_connection = pg.connect_to_db(m_ansible_module, fail_on_conn=False)
|
||||
conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
|
||||
db_connection = pg.connect_to_db(m_ansible_module, conn_params, fail_on_conn=False)
|
||||
|
||||
assert m_ansible_module.err_msg == ''
|
||||
assert 'PostgreSQL server is unavailable' in m_ansible_module.warn_msg
|
||||
|
@ -271,7 +284,8 @@ class TestConnectToDb():
|
|||
# case 1: psycopg2.__version >= 2.4.2 (the default in m_psycopg2)
|
||||
monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)
|
||||
|
||||
db_connection = pg.connect_to_db(m_ansible_module, autocommit=True)
|
||||
conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
|
||||
db_connection = pg.connect_to_db(m_ansible_module, conn_params, autocommit=True)
|
||||
cursor = db_connection.cursor()
|
||||
|
||||
# if errors, db_connection returned as None:
|
||||
|
@ -283,10 +297,26 @@ class TestConnectToDb():
|
|||
m_psycopg2.__version__ = '2.4.1'
|
||||
monkeypatch.setattr(pg, 'psycopg2', m_psycopg2)
|
||||
|
||||
db_connection = pg.connect_to_db(m_ansible_module, autocommit=True)
|
||||
conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params)
|
||||
db_connection = pg.connect_to_db(m_ansible_module, conn_params, autocommit=True)
|
||||
cursor = db_connection.cursor()
|
||||
|
||||
# if errors, db_connection returned as None:
|
||||
assert isinstance(db_connection, DbConnection)
|
||||
assert isinstance(cursor, Cursor)
|
||||
assert 'psycopg2 must be at least 2.4.3' in m_ansible_module.err_msg
|
||||
|
||||
|
||||
class TestGetConnParams():
|
||||
|
||||
"""Namespace for testing get_conn_params() function."""
|
||||
|
||||
def test_get_conn_params_def(self, m_ansible_module):
|
||||
"""Test get_conn_params(), warn_db_default kwarg is default."""
|
||||
assert pg.get_conn_params(m_ansible_module, INPUT_DICT) == EXPECTED_DICT
|
||||
assert m_ansible_module.warn_msg == 'Database name has not been passed, used default database to connect to.'
|
||||
|
||||
def test_get_conn_params_warn_db_def_false(self, m_ansible_module):
|
||||
"""Test get_conn_params(), warn_db_default kwarg is False."""
|
||||
assert pg.get_conn_params(m_ansible_module, INPUT_DICT, warn_db_default=False) == EXPECTED_DICT
|
||||
assert m_ansible_module.warn_msg == ''
|
||||
|
|
Loading…
Reference in a new issue