Escape mysql identifiers

This commit is contained in:
Toshio Kuratomi 2014-11-25 01:46:09 -08:00 committed by Matt Clay
parent 87b2afc272
commit bed60553ca
3 changed files with 64 additions and 23 deletions

View file

@ -118,7 +118,7 @@ def db_exists(cursor, db):
return bool(res)
def db_delete(cursor, db):
query = "DROP DATABASE `%s`" % db
query = "DROP DATABASE %s" % mysql_quote_identifier(db, 'database')
cursor.execute(query)
return True
@ -190,12 +190,14 @@ def db_import(module, host, user, password, db_name, target, port, socket=None):
return rc, stdout, stderr
def db_create(cursor, db, encoding, collation):
query_params = dict(enc=encoding, collate=collation)
query = ['CREATE DATABASE %s' % mysql_quote_identifier(db, 'database')]
if encoding:
encoding = " CHARACTER SET %s" % encoding
query.append("CHARACTER SET %(enc)s")
if collation:
collation = " COLLATE %s" % collation
query = "CREATE DATABASE `%s`%s%s" % (db, encoding, collation)
res = cursor.execute(query)
query.append("COLLATE %(collate)s")
query = ' '.join(query)
res = cursor.execute(query, query_params)
return True
def strip_quotes(s):
@ -360,4 +362,6 @@ def main():
# import module snippets
from ansible.module_utils.basic import *
main()
from ansible.module_utils.database import *
if __name__ == '__main__':
main()

View file

@ -151,6 +151,19 @@ except ImportError:
else:
mysqldb_found = True
VALID_PRIVS = frozenset(('CREATE', 'DROP', 'GRANT OPTION', 'LOCK TABLES',
'REFERENCES', 'EVENT', 'ALTER', 'DELETE', 'INDEX',
'INSERT', 'SELECT', 'UPDATE',
'CREATE TEMPORARY TABLES', 'TRIGGER', 'CREATE VIEW',
'SHOW VIEW', 'ALTER ROUTINE', 'CREATE ROUTINE',
'EXECUTE', 'FILE', 'CREATE USER', 'PROCESS', 'RELOAD',
'REPLICATION CLIENT', 'REPLICATION SLAVE',
'SHOW DATABASES', 'SHUTDOWN', 'SUPER', 'ALL',
'ALL PRIVILEGES', 'USAGE',))
class InvalidPrivsError(Exception):
pass
# ===========================================
# MySQL module specific support methods.
#
@ -217,7 +230,7 @@ def user_mod(cursor, user, host, password, new_priv, append_privs):
return changed
def user_delete(cursor, user, host):
cursor.execute("DROP USER %s@%s", (user,host))
cursor.execute("DROP USER %s@%s", (user, host))
return True
def privileges_get(cursor, user,host):
@ -231,7 +244,7 @@ def privileges_get(cursor, user,host):
The dictionary format is the same as that returned by privileges_unpack() below.
"""
output = {}
cursor.execute("SHOW GRANTS FOR %s@%s", (user,host))
cursor.execute("SHOW GRANTS FOR %s@%s", (user, host))
grants = cursor.fetchall()
def pick(x):
@ -274,6 +287,9 @@ def privileges_unpack(priv):
pieces[0] = '.'.join(pieces[0])
output[pieces[0]] = pieces[1].upper().split(',')
new_privs = frozenset(output[pieces[0]])
if not new_privs.issubset(VALID_PRIVS):
raise InvalidPrivsError('Invalid privileges specified: %s' % new_privs.difference(VALID_PRIVS))
if '*.*' not in output:
output['*.*'] = ['USAGE']
@ -282,18 +298,24 @@ def privileges_unpack(priv):
def privileges_revoke(cursor, user,host,db_table,grant_option):
if grant_option:
query = "REVOKE GRANT OPTION ON %s FROM '%s'@'%s'" % (db_table,user,host)
cursor.execute(query)
query = "REVOKE ALL PRIVILEGES ON %s FROM '%s'@'%s'" % (db_table,user,host)
cursor.execute(query)
query = ["REVOKE GRANT OPTION ON %s" % mysql_quote_identifier(db_table, 'table')]
query.append("FROM %s@%s")
query = ' '.join(query)
cursor.execute(query, (user, host))
query = ["REVOKE ALL PRIVILEGES ON %s" % mysql_quote_identifier(db_table, 'table')]
query.append("FROM %s@%s")
query = ' '.join(query)
cursor.execute(query, (user, host))
def privileges_grant(cursor, user,host,db_table,priv):
priv_string = ",".join(filter(lambda x: x != 'GRANT', priv))
query = "GRANT %s ON %s TO '%s'@'%s'" % (priv_string,db_table,user,host)
query = ["GRANT %s ON %s" % (priv_string, mysql_quote_identifier(db_table, 'table'))]
query.append("TO %s@%s")
if 'GRANT' in priv:
query = query + " WITH GRANT OPTION"
cursor.execute(query)
query.append("WITH GRANT OPTION")
query = ' '.join(query)
cursor.execute(query, (user, host))
def strip_quotes(s):
@ -425,8 +447,8 @@ def main():
if priv is not None:
try:
priv = privileges_unpack(priv)
except:
module.fail_json(msg="invalid privileges string")
except Exception, e:
module.fail_json(msg="invalid privileges string: %s" % str(e))
# Either the caller passes both a username and password with which to connect to
# mysql, or they pass neither and allow this module to read the credentials from
@ -459,11 +481,17 @@ def main():
if state == "present":
if user_exists(cursor, user, host):
changed = user_mod(cursor, user, host, password, priv, append_privs)
try:
changed = user_mod(cursor, user, host, password, priv, append_privs)
except SQLParseError, e:
module.fail_json(msg=str(e))
else:
if password is None:
module.fail_json(msg="password parameter required when adding a user")
changed = user_add(cursor, user, host, password, priv)
try:
changed = user_add(cursor, user, host, password, priv)
except SQLParseError, e:
module.fail_json(msg=str(e))
elif state == "absent":
if user_exists(cursor, user, host):
changed = user_delete(cursor, user, host)
@ -473,4 +501,6 @@ def main():
# import module snippets
from ansible.module_utils.basic import *
main()
from ansible.module_utils.database import *
if __name__ == '__main__':
main()

View file

@ -103,7 +103,7 @@ def typedvalue(value):
def getvariable(cursor, mysqlvar):
cursor.execute("SHOW VARIABLES LIKE '" + mysqlvar + "'")
cursor.execute("SHOW VARIABLES LIKE %s", (mysqlvar,))
mysqlvar_val = cursor.fetchall()
return mysqlvar_val
@ -116,8 +116,11 @@ def setvariable(cursor, mysqlvar, value):
should be passed as numeric literals.
"""
query = ["SET GLOBAL %s" % mysql_quote_identifier(mysqlvar, 'vars') ]
query.append(" = %s")
query = ' '.join(query)
try:
cursor.execute("SET GLOBAL " + mysqlvar + " = %s", (value,))
cursor.execute(query, (value,))
cursor.fetchall()
result = True
except Exception, e:
@ -242,7 +245,10 @@ def main():
value_actual = typedvalue(mysqlvar_val[0][1])
if value_wanted == value_actual:
module.exit_json(msg="Variable already set to requested value", changed=False)
result = setvariable(cursor, mysqlvar, value_wanted)
try:
result = setvariable(cursor, mysqlvar, value_wanted)
except SQLParseError, e:
result = str(e)
if result is True:
module.exit_json(msg="Variable change succeeded prev_value=%s" % value_actual, changed=True)
else:
@ -250,4 +256,5 @@ def main():
# import module snippets
from ansible.module_utils.basic import *
from ansible.module_utils.database import *
main()