Pep8 fixes for mysql module (#23923)

Signed-off-by: Abhijeet Kasurde <akasurde@redhat.com>
This commit is contained in:
Abhijeet Kasurde 2017-04-26 17:26:35 +05:30 committed by John R Barker
parent 9d5c399313
commit 9fbbb5e10f
5 changed files with 63 additions and 42 deletions

View file

@ -128,15 +128,18 @@ else:
# MySQL module specific support methods. # MySQL module specific support methods.
# #
def db_exists(cursor, db): def db_exists(cursor, db):
res = cursor.execute("SHOW DATABASES LIKE %s", (db.replace("_", "\_"),)) res = cursor.execute("SHOW DATABASES LIKE %s", (db.replace("_", "\_"),))
return bool(res) return bool(res)
def db_delete(cursor, db): def db_delete(cursor, db):
query = "DROP DATABASE %s" % mysql_quote_identifier(db, 'database') query = "DROP DATABASE %s" % mysql_quote_identifier(db, 'database')
cursor.execute(query) cursor.execute(query)
return True return True
def db_dump(module, host, user, password, db_name, target, all_databases, port, config_file, socket=None, ssl_cert=None, ssl_key=None, ssl_ca=None, def db_dump(module, host, user, password, db_name, target, all_databases, port, config_file, socket=None, ssl_cert=None, ssl_key=None, ssl_ca=None,
single_transaction=None, quick=None): single_transaction=None, quick=None):
cmd = module.get_bin_path('mysqldump', True) cmd = module.get_bin_path('mysqldump', True)
@ -182,6 +185,7 @@ def db_dump(module, host, user, password, db_name, target, all_databases, port,
rc, stdout, stderr = module.run_command(cmd, use_unsafe_shell=True) rc, stdout, stderr = module.run_command(cmd, use_unsafe_shell=True)
return rc, stdout, stderr return rc, stdout, stderr
def db_import(module, host, user, password, db_name, target, all_databases, port, config_file, socket=None, ssl_cert=None, ssl_key=None, ssl_ca=None): def db_import(module, host, user, password, db_name, target, all_databases, port, config_file, socket=None, ssl_cert=None, ssl_key=None, ssl_ca=None):
if not os.path.exists(target): if not os.path.exists(target):
return module.fail_json(msg="target %s does not exist on the host" % target) return module.fail_json(msg="target %s does not exist on the host" % target)
@ -234,6 +238,7 @@ def db_import(module, host, user, password, db_name, target, all_databases, port
rc, stdout, stderr = module.run_command(cmd, use_unsafe_shell=True) rc, stdout, stderr = module.run_command(cmd, use_unsafe_shell=True)
return rc, stdout, stderr return rc, stdout, stderr
def db_create(cursor, db, encoding, collation): def db_create(cursor, db, encoding, collation):
query_params = dict(enc=encoding, collate=collation) query_params = dict(enc=encoding, collate=collation)
query = ['CREATE DATABASE %s' % mysql_quote_identifier(db, 'database')] query = ['CREATE DATABASE %s' % mysql_quote_identifier(db, 'database')]
@ -242,13 +247,14 @@ def db_create(cursor, db, encoding, collation):
if collation: if collation:
query.append("COLLATE %(collate)s") query.append("COLLATE %(collate)s")
query = ' '.join(query) query = ' '.join(query)
res = cursor.execute(query, query_params) cursor.execute(query, query_params)
return True return True
# =========================================== # ===========================================
# Module execution. # Module execution.
# #
def main(): def main():
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
@ -298,7 +304,7 @@ def main():
if state in ['dump', 'import']: if state in ['dump', 'import']:
if target is None: if target is None:
module.fail_json(msg="with state=%s target is required" % (state)) module.fail_json(msg="with state=%s target is required" % state)
if db == 'all': if db == 'all':
db = 'mysql' db = 'mysql'
all_databases = True all_databases = True
@ -339,7 +345,8 @@ def main():
else: else:
rc, stdout, stderr = db_dump(module, login_host, login_user, rc, stdout, stderr = db_dump(module, login_host, login_user,
login_password, db, target, all_databases, login_password, db, target, all_databases,
login_port, config_file, socket, ssl_cert, ssl_key, ssl_ca, single_transaction, quick) login_port, config_file, socket, ssl_cert, ssl_key,
ssl_ca, single_transaction, quick)
if rc != 0: if rc != 0:
module.fail_json(msg="%s" % stderr) module.fail_json(msg="%s" % stderr)
else: else:
@ -350,8 +357,10 @@ def main():
module.exit_json(changed=True, db=db) module.exit_json(changed=True, db=db)
else: else:
rc, stdout, stderr = db_import(module, login_host, login_user, rc, stdout, stderr = db_import(module, login_host, login_user,
login_password, db, target, all_databases, login_password, db, target,
login_port, config_file, socket, ssl_cert, ssl_key, ssl_ca) all_databases,
login_port, config_file,
socket, ssl_cert, ssl_key, ssl_ca)
if rc != 0: if rc != 0:
module.fail_json(msg="%s" % stderr) module.fail_json(msg="%s" % stderr)
else: else:

View file

@ -214,8 +214,7 @@ EXAMPLES = """
# password=n<_665{vS43y # password=n<_665{vS43y
""" """
import getpass
import tempfile
import re import re
import string import string
try: try:
@ -236,6 +235,7 @@ VALID_PRIVS = frozenset(('CREATE', 'DROP', 'GRANT', 'GRANT OPTION',
'REPLICATION SLAVE', 'SHOW DATABASES', 'SHUTDOWN', 'REPLICATION SLAVE', 'SHOW DATABASES', 'SHUTDOWN',
'SUPER', 'ALL', 'ALL PRIVILEGES', 'USAGE', 'REQUIRESSL')) 'SUPER', 'ALL', 'ALL PRIVILEGES', 'USAGE', 'REQUIRESSL'))
class InvalidPrivsError(Exception): class InvalidPrivsError(Exception):
pass pass
@ -243,6 +243,7 @@ class InvalidPrivsError(Exception):
# MySQL module specific support methods. # MySQL module specific support methods.
# #
# User Authentication Management was change in MySQL 5.7 # User Authentication Management was change in MySQL 5.7
# This is a generic check for if the server version is less than version 5.7 # This is a generic check for if the server version is less than version 5.7
def server_version_check(cursor): def server_version_check(cursor):
@ -255,11 +256,12 @@ def server_version_check(cursor):
# mariadb and the old-style update continues to work # mariadb and the old-style update continues to work
if 'mariadb' in version_str.lower(): if 'mariadb' in version_str.lower():
return True return True
if (int(version[0]) <= 5 and int(version[1]) < 7): if int(version[0]) <= 5 and int(version[1]) < 7:
return True return True
else: else:
return False return False
def get_mode(cursor): def get_mode(cursor):
cursor.execute('SELECT @@GLOBAL.sql_mode') cursor.execute('SELECT @@GLOBAL.sql_mode')
result = cursor.fetchone() result = cursor.fetchone()
@ -270,6 +272,7 @@ def get_mode(cursor):
mode = 'NOTANSI' mode = 'NOTANSI'
return mode return mode
def user_exists(cursor, user, host, host_all): def user_exists(cursor, user, host, host_all):
if host_all: if host_all:
cursor.execute("SELECT count(*) FROM user WHERE user = %s", ([user])) cursor.execute("SELECT count(*) FROM user WHERE user = %s", ([user]))
@ -279,6 +282,7 @@ def user_exists(cursor, user, host, host_all):
count = cursor.fetchone() count = cursor.fetchone()
return count[0] > 0 return count[0] > 0
def user_add(cursor, user, host, host_all, password, encrypted, new_priv, check_mode): def user_add(cursor, user, host, host_all, password, encrypted, new_priv, check_mode):
# we cannot create users without a proper hostname # we cannot create users without a proper hostname
if host_all: if host_all:
@ -298,6 +302,7 @@ def user_add(cursor, user, host, host_all, password, encrypted, new_priv, check_
privileges_grant(cursor, user, host, db_table, priv) privileges_grant(cursor, user, host, db_table, priv)
return True return True
def is_hash(password): def is_hash(password):
ishash = False ishash = False
if len(password) == 41 and password[0] == '*': if len(password) == 41 and password[0] == '*':
@ -305,6 +310,7 @@ def is_hash(password):
ishash = True ishash = True
return ishash return ishash
def user_mod(cursor, user, host, host_all, password, encrypted, new_priv, append_privs, module): def user_mod(cursor, user, host, host_all, password, encrypted, new_priv, append_privs, module):
changed = False changed = False
grant_option = False grant_option = False
@ -385,7 +391,7 @@ def user_mod(cursor, user, host, host_all, password, encrypted, new_priv, append
db_table_intersect = set(new_priv.keys()) & set(curr_priv.keys()) db_table_intersect = set(new_priv.keys()) & set(curr_priv.keys())
for db_table in db_table_intersect: for db_table in db_table_intersect:
priv_diff = set(new_priv[db_table]) ^ set(curr_priv[db_table]) priv_diff = set(new_priv[db_table]) ^ set(curr_priv[db_table])
if (len(priv_diff) > 0): if len(priv_diff) > 0:
if module.check_mode: if module.check_mode:
return True return True
if not append_privs: if not append_privs:
@ -395,6 +401,7 @@ def user_mod(cursor, user, host, host_all, password, encrypted, new_priv, append
return changed return changed
def user_delete(cursor, user, host, host_all, check_mode): def user_delete(cursor, user, host, host_all, check_mode):
if check_mode: if check_mode:
return True return True
@ -409,6 +416,7 @@ def user_delete(cursor, user, host, host_all, check_mode):
return True return True
def user_get_hostnames(cursor, user): def user_get_hostnames(cursor, user):
cursor.execute("SELECT Host FROM mysql.user WHERE user = %s", user) cursor.execute("SELECT Host FROM mysql.user WHERE user = %s", user)
hostnames_raw = cursor.fetchall() hostnames_raw = cursor.fetchall()
@ -419,6 +427,7 @@ def user_get_hostnames(cursor, user):
return hostnames return hostnames
def privileges_get(cursor, user, host): def privileges_get(cursor, user, host):
""" MySQL doesn't have a better method of getting privileges aside from the """ MySQL doesn't have a better method of getting privileges aside from the
SHOW GRANTS query syntax, which requires us to then parse the returned string. SHOW GRANTS query syntax, which requires us to then parse the returned string.
@ -453,6 +462,7 @@ def privileges_get(cursor, user,host):
output[db] = privileges output[db] = privileges
return output return output
def privileges_unpack(priv, mode): def privileges_unpack(priv, mode):
""" Take a privileges string, typically passed as a parameter, and unserialize """ Take a privileges string, typically passed as a parameter, and unserialize
it into a dictionary, the same format as privileges_get() above. We have this it into a dictionary, the same format as privileges_get() above. We have this
@ -501,6 +511,7 @@ def privileges_unpack(priv, mode):
return output return output
def privileges_revoke(cursor, user, host, db_table, priv, grant_option): def privileges_revoke(cursor, user, host, db_table, priv, grant_option):
# Escape '%' since mysql db.execute() uses a format string # Escape '%' since mysql db.execute() uses a format string
db_table = db_table.replace('%', '%%') db_table = db_table.replace('%', '%%')
@ -515,6 +526,7 @@ def privileges_revoke(cursor, user,host,db_table,priv,grant_option):
query = ' '.join(query) query = ' '.join(query)
cursor.execute(query, (user, host)) cursor.execute(query, (user, host))
def privileges_grant(cursor, user, host, db_table, priv): def privileges_grant(cursor, user, host, db_table, priv):
# Escape '%' since mysql db.execute uses a format string and the # Escape '%' since mysql db.execute uses a format string and the
# specification of db and table often use a % (SQL wildcard) # specification of db and table often use a % (SQL wildcard)
@ -533,6 +545,7 @@ def privileges_grant(cursor, user,host,db_table,priv):
# Module execution. # Module execution.
# #
def main(): def main():
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(
@ -645,5 +658,6 @@ def main():
from ansible.module_utils.basic import * from ansible.module_utils.basic import *
from ansible.module_utils.database import * from ansible.module_utils.database import *
from ansible.module_utils.mysql import * from ansible.module_utils.mysql import *
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View file

@ -104,6 +104,7 @@ def getvariable(cursor, mysqlvar):
else: else:
return None return None
def setvariable(cursor, mysqlvar, value): def setvariable(cursor, mysqlvar, value):
""" Set a global mysql variable to a given value """ Set a global mysql variable to a given value
@ -122,6 +123,7 @@ def setvariable(cursor, mysqlvar, value):
result = str(e) result = str(e)
return result return result
def main(): def main():
module = AnsibleModule( module = AnsibleModule(
argument_spec=dict( argument_spec=dict(

View file

@ -394,10 +394,6 @@ lib/ansible/modules/database/misc/riak.py
lib/ansible/modules/database/mongodb/mongodb_parameter.py lib/ansible/modules/database/mongodb/mongodb_parameter.py
lib/ansible/modules/database/mongodb/mongodb_user.py lib/ansible/modules/database/mongodb/mongodb_user.py
lib/ansible/modules/database/mssql/mssql_db.py lib/ansible/modules/database/mssql/mssql_db.py
lib/ansible/modules/database/mysql/mysql_db.py
lib/ansible/modules/database/mysql/mysql_replication.py
lib/ansible/modules/database/mysql/mysql_user.py
lib/ansible/modules/database/mysql/mysql_variables.py
lib/ansible/modules/database/postgresql/postgresql_db.py lib/ansible/modules/database/postgresql/postgresql_db.py
lib/ansible/modules/database/postgresql/postgresql_ext.py lib/ansible/modules/database/postgresql/postgresql_ext.py
lib/ansible/modules/database/postgresql/postgresql_lang.py lib/ansible/modules/database/postgresql/postgresql_lang.py