Escape mysql identifiers

This commit is contained in:
Toshio Kuratomi 2014-11-25 01:46:09 -08:00 committed by Matt Clay
parent 87b2afc272
commit bed60553ca
3 changed files with 64 additions and 23 deletions

View file

@ -118,7 +118,7 @@ def db_exists(cursor, db):
return bool(res) return bool(res)
def db_delete(cursor, db): def db_delete(cursor, db):
query = "DROP DATABASE `%s`" % db query = "DROP DATABASE %s" % mysql_quote_identifier(db, 'database')
cursor.execute(query) cursor.execute(query)
return True return True
@ -190,12 +190,14 @@ def db_import(module, host, user, password, db_name, target, port, socket=None):
return rc, stdout, stderr return rc, stdout, stderr
def db_create(cursor, db, encoding, collation): def db_create(cursor, db, encoding, collation):
query_params = dict(enc=encoding, collate=collation)
query = ['CREATE DATABASE %s' % mysql_quote_identifier(db, 'database')]
if encoding: if encoding:
encoding = " CHARACTER SET %s" % encoding query.append("CHARACTER SET %(enc)s")
if collation: if collation:
collation = " COLLATE %s" % collation query.append("COLLATE %(collate)s")
query = "CREATE DATABASE `%s`%s%s" % (db, encoding, collation) query = ' '.join(query)
res = cursor.execute(query) res = cursor.execute(query, query_params)
return True return True
def strip_quotes(s): def strip_quotes(s):
@ -360,4 +362,6 @@ def main():
# import module snippets # import module snippets
from ansible.module_utils.basic import * from ansible.module_utils.basic import *
from ansible.module_utils.database import *
if __name__ == '__main__':
main() main()

View file

@ -151,6 +151,19 @@ except ImportError:
else: else:
mysqldb_found = True mysqldb_found = True
VALID_PRIVS = frozenset(('CREATE', 'DROP', 'GRANT OPTION', 'LOCK TABLES',
'REFERENCES', 'EVENT', 'ALTER', 'DELETE', 'INDEX',
'INSERT', 'SELECT', 'UPDATE',
'CREATE TEMPORARY TABLES', 'TRIGGER', 'CREATE VIEW',
'SHOW VIEW', 'ALTER ROUTINE', 'CREATE ROUTINE',
'EXECUTE', 'FILE', 'CREATE USER', 'PROCESS', 'RELOAD',
'REPLICATION CLIENT', 'REPLICATION SLAVE',
'SHOW DATABASES', 'SHUTDOWN', 'SUPER', 'ALL',
'ALL PRIVILEGES', 'USAGE',))
class InvalidPrivsError(Exception):
pass
# =========================================== # ===========================================
# MySQL module specific support methods. # MySQL module specific support methods.
# #
@ -274,6 +287,9 @@ def privileges_unpack(priv):
pieces[0] = '.'.join(pieces[0]) pieces[0] = '.'.join(pieces[0])
output[pieces[0]] = pieces[1].upper().split(',') output[pieces[0]] = pieces[1].upper().split(',')
new_privs = frozenset(output[pieces[0]])
if not new_privs.issubset(VALID_PRIVS):
raise InvalidPrivsError('Invalid privileges specified: %s' % new_privs.difference(VALID_PRIVS))
if '*.*' not in output: if '*.*' not in output:
output['*.*'] = ['USAGE'] output['*.*'] = ['USAGE']
@ -282,18 +298,24 @@ def privileges_unpack(priv):
def privileges_revoke(cursor, user,host,db_table,grant_option): def privileges_revoke(cursor, user,host,db_table,grant_option):
if grant_option: if grant_option:
query = "REVOKE GRANT OPTION ON %s FROM '%s'@'%s'" % (db_table,user,host) query = ["REVOKE GRANT OPTION ON %s" % mysql_quote_identifier(db_table, 'table')]
cursor.execute(query) query.append("FROM %s@%s")
query = "REVOKE ALL PRIVILEGES ON %s FROM '%s'@'%s'" % (db_table,user,host) query = ' '.join(query)
cursor.execute(query) cursor.execute(query, (user, host))
query = ["REVOKE ALL PRIVILEGES ON %s" % mysql_quote_identifier(db_table, 'table')]
query.append("FROM %s@%s")
query = ' '.join(query)
cursor.execute(query, (user, host))
def privileges_grant(cursor, user,host,db_table,priv): def privileges_grant(cursor, user,host,db_table,priv):
priv_string = ",".join(filter(lambda x: x != 'GRANT', priv)) priv_string = ",".join(filter(lambda x: x != 'GRANT', priv))
query = "GRANT %s ON %s TO '%s'@'%s'" % (priv_string,db_table,user,host) query = ["GRANT %s ON %s" % (priv_string, mysql_quote_identifier(db_table, 'table'))]
query.append("TO %s@%s")
if 'GRANT' in priv: if 'GRANT' in priv:
query = query + " WITH GRANT OPTION" query.append("WITH GRANT OPTION")
cursor.execute(query) query = ' '.join(query)
cursor.execute(query, (user, host))
def strip_quotes(s): def strip_quotes(s):
@ -425,8 +447,8 @@ def main():
if priv is not None: if priv is not None:
try: try:
priv = privileges_unpack(priv) priv = privileges_unpack(priv)
except: except Exception, e:
module.fail_json(msg="invalid privileges string") module.fail_json(msg="invalid privileges string: %s" % str(e))
# Either the caller passes both a username and password with which to connect to # Either the caller passes both a username and password with which to connect to
# mysql, or they pass neither and allow this module to read the credentials from # mysql, or they pass neither and allow this module to read the credentials from
@ -459,11 +481,17 @@ def main():
if state == "present": if state == "present":
if user_exists(cursor, user, host): if user_exists(cursor, user, host):
try:
changed = user_mod(cursor, user, host, password, priv, append_privs) changed = user_mod(cursor, user, host, password, priv, append_privs)
except SQLParseError, e:
module.fail_json(msg=str(e))
else: else:
if password is None: if password is None:
module.fail_json(msg="password parameter required when adding a user") module.fail_json(msg="password parameter required when adding a user")
try:
changed = user_add(cursor, user, host, password, priv) changed = user_add(cursor, user, host, password, priv)
except SQLParseError, e:
module.fail_json(msg=str(e))
elif state == "absent": elif state == "absent":
if user_exists(cursor, user, host): if user_exists(cursor, user, host):
changed = user_delete(cursor, user, host) changed = user_delete(cursor, user, host)
@ -473,4 +501,6 @@ def main():
# import module snippets # import module snippets
from ansible.module_utils.basic import * from ansible.module_utils.basic import *
from ansible.module_utils.database import *
if __name__ == '__main__':
main() main()

View file

@ -103,7 +103,7 @@ def typedvalue(value):
def getvariable(cursor, mysqlvar): def getvariable(cursor, mysqlvar):
cursor.execute("SHOW VARIABLES LIKE '" + mysqlvar + "'") cursor.execute("SHOW VARIABLES LIKE %s", (mysqlvar,))
mysqlvar_val = cursor.fetchall() mysqlvar_val = cursor.fetchall()
return mysqlvar_val return mysqlvar_val
@ -116,8 +116,11 @@ def setvariable(cursor, mysqlvar, value):
should be passed as numeric literals. should be passed as numeric literals.
""" """
query = ["SET GLOBAL %s" % mysql_quote_identifier(mysqlvar, 'vars') ]
query.append(" = %s")
query = ' '.join(query)
try: try:
cursor.execute("SET GLOBAL " + mysqlvar + " = %s", (value,)) cursor.execute(query, (value,))
cursor.fetchall() cursor.fetchall()
result = True result = True
except Exception, e: except Exception, e:
@ -242,7 +245,10 @@ def main():
value_actual = typedvalue(mysqlvar_val[0][1]) value_actual = typedvalue(mysqlvar_val[0][1])
if value_wanted == value_actual: if value_wanted == value_actual:
module.exit_json(msg="Variable already set to requested value", changed=False) module.exit_json(msg="Variable already set to requested value", changed=False)
try:
result = setvariable(cursor, mysqlvar, value_wanted) result = setvariable(cursor, mysqlvar, value_wanted)
except SQLParseError, e:
result = str(e)
if result is True: if result is True:
module.exit_json(msg="Variable change succeeded prev_value=%s" % value_actual, changed=True) module.exit_json(msg="Variable change succeeded prev_value=%s" % value_actual, changed=True)
else: else:
@ -250,4 +256,5 @@ def main():
# import module snippets # import module snippets
from ansible.module_utils.basic import * from ansible.module_utils.basic import *
from ansible.module_utils.database import *
main() main()