More robust quoting of database identifiers

Note: These aren't database values, those are already using the
appropriate Pyhton DB API method for quoting.
This commit is contained in:
Toshio Kuratomi 2014-11-24 20:51:27 -08:00 committed by Matt Clay
parent f7fafa8c16
commit 32aaa07325

View file

@ -124,7 +124,9 @@ class NotSupportedError(Exception):
#
def set_owner(cursor, db, owner):
query = "ALTER DATABASE \"%s\" OWNER TO \"%s\"" % (db, owner)
query = "ALTER DATABASE %s OWNER TO %s" % (
pg_quote_identifier(db, 'database'),
pg_quote_identifier(owner, 'role'))
cursor.execute(query)
return True
@ -141,7 +143,7 @@ def get_db_info(cursor, db):
FROM pg_database JOIN pg_roles ON pg_roles.oid = pg_database.datdba
WHERE datname = %(db)s
"""
cursor.execute(query, {'db':db})
cursor.execute(query, {'db': db})
return cursor.fetchone()
def db_exists(cursor, db):
@ -151,28 +153,28 @@ def db_exists(cursor, db):
def db_delete(cursor, db):
if db_exists(cursor, db):
query = "DROP DATABASE \"%s\"" % db
query = "DROP DATABASE %s" % pg_quote_identifier(db, 'database')
cursor.execute(query)
return True
else:
return False
def db_create(cursor, db, owner, template, encoding, lc_collate, lc_ctype):
params = dict(enc=encoding, collate=lc_collate, ctype=lc_ctype)
if not db_exists(cursor, db):
query_fragments = ['CREATE DATABASE %s' % pg_quote_identifier(db, 'database')]
if owner:
owner = " OWNER \"%s\"" % owner
query_fragments.append('OWNER %s' % pg_quote_identifier(owner, 'role'))
if template:
template = " TEMPLATE \"%s\"" % template
query_fragments.append('TEMPLATE %s' % pg_quote_identifier(template, 'database'))
if encoding:
encoding = " ENCODING '%s'" % encoding
query_fragments.append('ENCODING %(enc)s')
if lc_collate:
lc_collate = " LC_COLLATE '%s'" % lc_collate
query_fragments.append('LC_COLLATE %(collate)s')
if lc_ctype:
lc_ctype = " LC_CTYPE '%s'" % lc_ctype
query = 'CREATE DATABASE "%s"%s%s%s%s%s' % (db, owner,
template, encoding,
lc_collate, lc_ctype)
cursor.execute(query)
query_fragments.append('LC_CTYPE %(ctype)s')
query = ' '.join(query_fragments)
cursor.execute(query, params)
return True
else:
db_info = get_db_info(cursor, db)
@ -284,11 +286,17 @@ def main():
module.exit_json(changed=changed,db=db)
if state == "absent":
changed = db_delete(cursor, db)
try:
changed = db_delete(cursor, db)
except SQLParseError, e:
module.fail_json(msg=str(e))
elif state == "present":
changed = db_create(cursor, db, owner, template, encoding,
try:
changed = db_create(cursor, db, owner, template, encoding,
lc_collate, lc_ctype)
except SQLParseError, e:
module.fail_json(msg=str(e))
except NotSupportedError, e:
module.fail_json(msg=str(e))
except Exception, e:
@ -298,4 +306,6 @@ def main():
# import module snippets
from ansible.module_utils.basic import *
main()
from ansible.module_utils.database import *
if __name__ == '__main__':
main()