Escape mysql identifiers
This commit is contained in:
parent
87b2afc272
commit
bed60553ca
3 changed files with 64 additions and 23 deletions
|
@ -118,7 +118,7 @@ def db_exists(cursor, db):
|
||||||
return bool(res)
|
return bool(res)
|
||||||
|
|
||||||
def db_delete(cursor, db):
|
def db_delete(cursor, db):
|
||||||
query = "DROP DATABASE `%s`" % db
|
query = "DROP DATABASE %s" % mysql_quote_identifier(db, 'database')
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -190,12 +190,14 @@ def db_import(module, host, user, password, db_name, target, port, socket=None):
|
||||||
return rc, stdout, stderr
|
return rc, stdout, stderr
|
||||||
|
|
||||||
def db_create(cursor, db, encoding, collation):
|
def db_create(cursor, db, encoding, collation):
|
||||||
|
query_params = dict(enc=encoding, collate=collation)
|
||||||
|
query = ['CREATE DATABASE %s' % mysql_quote_identifier(db, 'database')]
|
||||||
if encoding:
|
if encoding:
|
||||||
encoding = " CHARACTER SET %s" % encoding
|
query.append("CHARACTER SET %(enc)s")
|
||||||
if collation:
|
if collation:
|
||||||
collation = " COLLATE %s" % collation
|
query.append("COLLATE %(collate)s")
|
||||||
query = "CREATE DATABASE `%s`%s%s" % (db, encoding, collation)
|
query = ' '.join(query)
|
||||||
res = cursor.execute(query)
|
res = cursor.execute(query, query_params)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def strip_quotes(s):
|
def strip_quotes(s):
|
||||||
|
@ -360,4 +362,6 @@ def main():
|
||||||
|
|
||||||
# import module snippets
|
# import module snippets
|
||||||
from ansible.module_utils.basic import *
|
from ansible.module_utils.basic import *
|
||||||
main()
|
from ansible.module_utils.database import *
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
|
@ -151,6 +151,19 @@ except ImportError:
|
||||||
else:
|
else:
|
||||||
mysqldb_found = True
|
mysqldb_found = True
|
||||||
|
|
||||||
|
VALID_PRIVS = frozenset(('CREATE', 'DROP', 'GRANT OPTION', 'LOCK TABLES',
|
||||||
|
'REFERENCES', 'EVENT', 'ALTER', 'DELETE', 'INDEX',
|
||||||
|
'INSERT', 'SELECT', 'UPDATE',
|
||||||
|
'CREATE TEMPORARY TABLES', 'TRIGGER', 'CREATE VIEW',
|
||||||
|
'SHOW VIEW', 'ALTER ROUTINE', 'CREATE ROUTINE',
|
||||||
|
'EXECUTE', 'FILE', 'CREATE USER', 'PROCESS', 'RELOAD',
|
||||||
|
'REPLICATION CLIENT', 'REPLICATION SLAVE',
|
||||||
|
'SHOW DATABASES', 'SHUTDOWN', 'SUPER', 'ALL',
|
||||||
|
'ALL PRIVILEGES', 'USAGE',))
|
||||||
|
|
||||||
|
class InvalidPrivsError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
# ===========================================
|
# ===========================================
|
||||||
# MySQL module specific support methods.
|
# MySQL module specific support methods.
|
||||||
#
|
#
|
||||||
|
@ -217,7 +230,7 @@ def user_mod(cursor, user, host, password, new_priv, append_privs):
|
||||||
return changed
|
return changed
|
||||||
|
|
||||||
def user_delete(cursor, user, host):
|
def user_delete(cursor, user, host):
|
||||||
cursor.execute("DROP USER %s@%s", (user,host))
|
cursor.execute("DROP USER %s@%s", (user, host))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def privileges_get(cursor, user,host):
|
def privileges_get(cursor, user,host):
|
||||||
|
@ -231,7 +244,7 @@ def privileges_get(cursor, user,host):
|
||||||
The dictionary format is the same as that returned by privileges_unpack() below.
|
The dictionary format is the same as that returned by privileges_unpack() below.
|
||||||
"""
|
"""
|
||||||
output = {}
|
output = {}
|
||||||
cursor.execute("SHOW GRANTS FOR %s@%s", (user,host))
|
cursor.execute("SHOW GRANTS FOR %s@%s", (user, host))
|
||||||
grants = cursor.fetchall()
|
grants = cursor.fetchall()
|
||||||
|
|
||||||
def pick(x):
|
def pick(x):
|
||||||
|
@ -274,6 +287,9 @@ def privileges_unpack(priv):
|
||||||
pieces[0] = '.'.join(pieces[0])
|
pieces[0] = '.'.join(pieces[0])
|
||||||
|
|
||||||
output[pieces[0]] = pieces[1].upper().split(',')
|
output[pieces[0]] = pieces[1].upper().split(',')
|
||||||
|
new_privs = frozenset(output[pieces[0]])
|
||||||
|
if not new_privs.issubset(VALID_PRIVS):
|
||||||
|
raise InvalidPrivsError('Invalid privileges specified: %s' % new_privs.difference(VALID_PRIVS))
|
||||||
|
|
||||||
if '*.*' not in output:
|
if '*.*' not in output:
|
||||||
output['*.*'] = ['USAGE']
|
output['*.*'] = ['USAGE']
|
||||||
|
@ -282,18 +298,24 @@ def privileges_unpack(priv):
|
||||||
|
|
||||||
def privileges_revoke(cursor, user,host,db_table,grant_option):
|
def privileges_revoke(cursor, user,host,db_table,grant_option):
|
||||||
if grant_option:
|
if grant_option:
|
||||||
query = "REVOKE GRANT OPTION ON %s FROM '%s'@'%s'" % (db_table,user,host)
|
query = ["REVOKE GRANT OPTION ON %s" % mysql_quote_identifier(db_table, 'table')]
|
||||||
cursor.execute(query)
|
query.append("FROM %s@%s")
|
||||||
query = "REVOKE ALL PRIVILEGES ON %s FROM '%s'@'%s'" % (db_table,user,host)
|
query = ' '.join(query)
|
||||||
cursor.execute(query)
|
cursor.execute(query, (user, host))
|
||||||
|
query = ["REVOKE ALL PRIVILEGES ON %s" % mysql_quote_identifier(db_table, 'table')]
|
||||||
|
query.append("FROM %s@%s")
|
||||||
|
query = ' '.join(query)
|
||||||
|
cursor.execute(query, (user, host))
|
||||||
|
|
||||||
def privileges_grant(cursor, user,host,db_table,priv):
|
def privileges_grant(cursor, user,host,db_table,priv):
|
||||||
|
|
||||||
priv_string = ",".join(filter(lambda x: x != 'GRANT', priv))
|
priv_string = ",".join(filter(lambda x: x != 'GRANT', priv))
|
||||||
query = "GRANT %s ON %s TO '%s'@'%s'" % (priv_string,db_table,user,host)
|
query = ["GRANT %s ON %s" % (priv_string, mysql_quote_identifier(db_table, 'table'))]
|
||||||
|
query.append("TO %s@%s")
|
||||||
if 'GRANT' in priv:
|
if 'GRANT' in priv:
|
||||||
query = query + " WITH GRANT OPTION"
|
query.append("WITH GRANT OPTION")
|
||||||
cursor.execute(query)
|
query = ' '.join(query)
|
||||||
|
cursor.execute(query, (user, host))
|
||||||
|
|
||||||
|
|
||||||
def strip_quotes(s):
|
def strip_quotes(s):
|
||||||
|
@ -425,8 +447,8 @@ def main():
|
||||||
if priv is not None:
|
if priv is not None:
|
||||||
try:
|
try:
|
||||||
priv = privileges_unpack(priv)
|
priv = privileges_unpack(priv)
|
||||||
except:
|
except Exception, e:
|
||||||
module.fail_json(msg="invalid privileges string")
|
module.fail_json(msg="invalid privileges string: %s" % str(e))
|
||||||
|
|
||||||
# Either the caller passes both a username and password with which to connect to
|
# Either the caller passes both a username and password with which to connect to
|
||||||
# mysql, or they pass neither and allow this module to read the credentials from
|
# mysql, or they pass neither and allow this module to read the credentials from
|
||||||
|
@ -459,11 +481,17 @@ def main():
|
||||||
|
|
||||||
if state == "present":
|
if state == "present":
|
||||||
if user_exists(cursor, user, host):
|
if user_exists(cursor, user, host):
|
||||||
changed = user_mod(cursor, user, host, password, priv, append_privs)
|
try:
|
||||||
|
changed = user_mod(cursor, user, host, password, priv, append_privs)
|
||||||
|
except SQLParseError, e:
|
||||||
|
module.fail_json(msg=str(e))
|
||||||
else:
|
else:
|
||||||
if password is None:
|
if password is None:
|
||||||
module.fail_json(msg="password parameter required when adding a user")
|
module.fail_json(msg="password parameter required when adding a user")
|
||||||
changed = user_add(cursor, user, host, password, priv)
|
try:
|
||||||
|
changed = user_add(cursor, user, host, password, priv)
|
||||||
|
except SQLParseError, e:
|
||||||
|
module.fail_json(msg=str(e))
|
||||||
elif state == "absent":
|
elif state == "absent":
|
||||||
if user_exists(cursor, user, host):
|
if user_exists(cursor, user, host):
|
||||||
changed = user_delete(cursor, user, host)
|
changed = user_delete(cursor, user, host)
|
||||||
|
@ -473,4 +501,6 @@ def main():
|
||||||
|
|
||||||
# import module snippets
|
# import module snippets
|
||||||
from ansible.module_utils.basic import *
|
from ansible.module_utils.basic import *
|
||||||
main()
|
from ansible.module_utils.database import *
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
|
@ -103,7 +103,7 @@ def typedvalue(value):
|
||||||
|
|
||||||
|
|
||||||
def getvariable(cursor, mysqlvar):
|
def getvariable(cursor, mysqlvar):
|
||||||
cursor.execute("SHOW VARIABLES LIKE '" + mysqlvar + "'")
|
cursor.execute("SHOW VARIABLES LIKE %s", (mysqlvar,))
|
||||||
mysqlvar_val = cursor.fetchall()
|
mysqlvar_val = cursor.fetchall()
|
||||||
return mysqlvar_val
|
return mysqlvar_val
|
||||||
|
|
||||||
|
@ -116,8 +116,11 @@ def setvariable(cursor, mysqlvar, value):
|
||||||
should be passed as numeric literals.
|
should be passed as numeric literals.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
query = ["SET GLOBAL %s" % mysql_quote_identifier(mysqlvar, 'vars') ]
|
||||||
|
query.append(" = %s")
|
||||||
|
query = ' '.join(query)
|
||||||
try:
|
try:
|
||||||
cursor.execute("SET GLOBAL " + mysqlvar + " = %s", (value,))
|
cursor.execute(query, (value,))
|
||||||
cursor.fetchall()
|
cursor.fetchall()
|
||||||
result = True
|
result = True
|
||||||
except Exception, e:
|
except Exception, e:
|
||||||
|
@ -242,7 +245,10 @@ def main():
|
||||||
value_actual = typedvalue(mysqlvar_val[0][1])
|
value_actual = typedvalue(mysqlvar_val[0][1])
|
||||||
if value_wanted == value_actual:
|
if value_wanted == value_actual:
|
||||||
module.exit_json(msg="Variable already set to requested value", changed=False)
|
module.exit_json(msg="Variable already set to requested value", changed=False)
|
||||||
result = setvariable(cursor, mysqlvar, value_wanted)
|
try:
|
||||||
|
result = setvariable(cursor, mysqlvar, value_wanted)
|
||||||
|
except SQLParseError, e:
|
||||||
|
result = str(e)
|
||||||
if result is True:
|
if result is True:
|
||||||
module.exit_json(msg="Variable change succeeded prev_value=%s" % value_actual, changed=True)
|
module.exit_json(msg="Variable change succeeded prev_value=%s" % value_actual, changed=True)
|
||||||
else:
|
else:
|
||||||
|
@ -250,4 +256,5 @@ def main():
|
||||||
|
|
||||||
# import module snippets
|
# import module snippets
|
||||||
from ansible.module_utils.basic import *
|
from ansible.module_utils.basic import *
|
||||||
|
from ansible.module_utils.database import *
|
||||||
main()
|
main()
|
||||||
|
|
Loading…
Reference in a new issue