From 32aaa07325c74a5394c10329ab5fe5b3494e9eba Mon Sep 17 00:00:00 2001 From: Toshio Kuratomi Date: Mon, 24 Nov 2014 20:51:27 -0800 Subject: [PATCH] More robust quoting of database identifiers Note: These aren't database values, those are already using the appropriate Pyhton DB API method for quoting. --- .../database/postgresql/postgresql_db.py | 40 ++++++++++++------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/lib/ansible/modules/database/postgresql/postgresql_db.py b/lib/ansible/modules/database/postgresql/postgresql_db.py index 605be621601..f965eac211a 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_db.py +++ b/lib/ansible/modules/database/postgresql/postgresql_db.py @@ -124,7 +124,9 @@ class NotSupportedError(Exception): # def set_owner(cursor, db, owner): - query = "ALTER DATABASE \"%s\" OWNER TO \"%s\"" % (db, owner) + query = "ALTER DATABASE %s OWNER TO %s" % ( + pg_quote_identifier(db, 'database'), + pg_quote_identifier(owner, 'role')) cursor.execute(query) return True @@ -141,7 +143,7 @@ def get_db_info(cursor, db): FROM pg_database JOIN pg_roles ON pg_roles.oid = pg_database.datdba WHERE datname = %(db)s """ - cursor.execute(query, {'db':db}) + cursor.execute(query, {'db': db}) return cursor.fetchone() def db_exists(cursor, db): @@ -151,28 +153,28 @@ def db_exists(cursor, db): def db_delete(cursor, db): if db_exists(cursor, db): - query = "DROP DATABASE \"%s\"" % db + query = "DROP DATABASE %s" % pg_quote_identifier(db, 'database') cursor.execute(query) return True else: return False def db_create(cursor, db, owner, template, encoding, lc_collate, lc_ctype): + params = dict(enc=encoding, collate=lc_collate, ctype=lc_ctype) if not db_exists(cursor, db): + query_fragments = ['CREATE DATABASE %s' % pg_quote_identifier(db, 'database')] if owner: - owner = " OWNER \"%s\"" % owner + query_fragments.append('OWNER %s' % pg_quote_identifier(owner, 'role')) if template: - template = " TEMPLATE \"%s\"" % template + query_fragments.append('TEMPLATE %s' % pg_quote_identifier(template, 'database')) if encoding: - encoding = " ENCODING '%s'" % encoding + query_fragments.append('ENCODING %(enc)s') if lc_collate: - lc_collate = " LC_COLLATE '%s'" % lc_collate + query_fragments.append('LC_COLLATE %(collate)s') if lc_ctype: - lc_ctype = " LC_CTYPE '%s'" % lc_ctype - query = 'CREATE DATABASE "%s"%s%s%s%s%s' % (db, owner, - template, encoding, - lc_collate, lc_ctype) - cursor.execute(query) + query_fragments.append('LC_CTYPE %(ctype)s') + query = ' '.join(query_fragments) + cursor.execute(query, params) return True else: db_info = get_db_info(cursor, db) @@ -284,11 +286,17 @@ def main(): module.exit_json(changed=changed,db=db) if state == "absent": - changed = db_delete(cursor, db) + try: + changed = db_delete(cursor, db) + except SQLParseError, e: + module.fail_json(msg=str(e)) elif state == "present": - changed = db_create(cursor, db, owner, template, encoding, + try: + changed = db_create(cursor, db, owner, template, encoding, lc_collate, lc_ctype) + except SQLParseError, e: + module.fail_json(msg=str(e)) except NotSupportedError, e: module.fail_json(msg=str(e)) except Exception, e: @@ -298,4 +306,6 @@ def main(): # import module snippets from ansible.module_utils.basic import * -main() +from ansible.module_utils.database import * +if __name__ == '__main__': + main()