diff --git a/lib/ansible/module_utils/database.py b/lib/ansible/module_utils/database.py new file mode 100644 index 00000000000..ca7942d0483 --- /dev/null +++ b/lib/ansible/module_utils/database.py @@ -0,0 +1,114 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c) 2014, Toshio Kuratomi +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +class SQLParseError(Exception): + pass + +class UnclosedQuoteError(SQLParseError): + pass + +# maps a type of identifier to the maximum number of dot levels that are +# allowed to specifiy that identifier. For example, a database column can be +# specified by up to 4 levels: database.schema.table.column +_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, schema=2, table=3, column=4, role=1) + +def _find_end_quote(identifier): + accumulate = 0 + while True: + try: + quote = identifier.index('"') + except ValueError: + raise UnclosedQuoteError + accumulate = accumulate + quote + try: + next_char = identifier[quote+1] + except IndexError: + return accumulate + if next_char == '"': + try: + identifier = identifier[quote+2:] + accumulate = accumulate + 2 + except IndexError: + raise UnclosedQuoteError + else: + return accumulate + + +def _identifier_parse(identifier): + if not identifier: + raise SQLParseError('Identifier name unspecified or unquoted trailing dot') + + already_quoted = False + if identifier.startswith('"'): + already_quoted = True + try: + end_quote = _find_end_quote(identifier[1:]) + 1 + except UnclosedQuoteError: + already_quoted = False + else: + if end_quote < len(identifier) - 1: + if identifier[end_quote+1] == '.': + dot = end_quote + 1 + first_identifier = identifier[:dot] + next_identifier = identifier[dot+1:] + further_identifiers = _identifier_parse(next_identifier) + further_identifiers.insert(0, first_identifier) + else: + import q ; q.q(identifier) + raise SQLParseError('User escaped identifiers must escape extra double quotes') + else: + further_identifiers = [identifier] + + if not already_quoted: + try: + dot = identifier.index('.') + except ValueError: + identifier = identifier.replace('"', '""') + identifier = ''.join(('"', identifier, '"')) + further_identifiers = [identifier] + else: + if dot == 0 or dot >= len(identifier) - 1: + identifier = identifier.replace('"', '""') + identifier = ''.join(('"', identifier, '"')) + further_identifiers = [identifier] + else: + first_identifier = identifier[:dot] + next_identifier = identifier[dot+1:] + further_identifiers = _identifier_parse(next_identifier) + first_identifier = first_identifier.replace('"', '""') + first_identifier = ''.join(('"', first_identifier, '"')) + further_identifiers.insert(0, first_identifier) + + return further_identifiers + + +def pg_quote_identifier(identifier, id_type): + identifier_fragments = _identifier_parse(identifier) + if len(identifier_fragments) > _IDENTIFIER_TO_DOT_LEVEL[id_type]: + raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, _IDENTIFIER_TO_DOT_LEVEL[id_type])) + return '.'.join(identifier_fragments) diff --git a/lib/ansible/modules/core b/lib/ansible/modules/core index 19b328c4df2..1b0afb137c7 160000 --- a/lib/ansible/modules/core +++ b/lib/ansible/modules/core @@ -1 +1 @@ -Subproject commit 19b328c4df2157b6c0191e9144236643ce2be890 +Subproject commit 1b0afb137c78383c47b3aaa31f4b849ddcb8783f diff --git a/test/units/TestModuleUtilsDatabase.py b/test/units/TestModuleUtilsDatabase.py new file mode 100644 index 00000000000..635eadb42c4 --- /dev/null +++ b/test/units/TestModuleUtilsDatabase.py @@ -0,0 +1,103 @@ +import collections +import mock +import os + +from nose import tools + +from ansible.module_utils.database import ( + pg_quote_identifier, + SQLParseError, +) + + +# Note: Using nose's generator test cases here so we can't inherit from +# unittest.TestCase +class TestQuotePgIdentifier(object): + + # These are all valid strings + # The results are based on interpreting the identifier as a table name + valid = { + # User quoted + '"public.table"': '"public.table"', + '"public"."table"': '"public"."table"', + '"schema test"."table test"': '"schema test"."table test"', + + # We quote part + 'public.table': '"public"."table"', + '"public".table': '"public"."table"', + 'public."table"': '"public"."table"', + 'schema test.table test': '"schema test"."table test"', + '"schema test".table test': '"schema test"."table test"', + 'schema test."table test"': '"schema test"."table test"', + + # Embedded double quotes + 'table "test"': '"table ""test"""', + 'public."table ""test"""': '"public"."table ""test"""', + 'public.table "test"': '"public"."table ""test"""', + 'schema "test".table': '"schema ""test"""."table"', + '"schema ""test""".table': '"schema ""test"""."table"', + '"""wat"""."""test"""': '"""wat"""."""test"""', + # Sigh, handle these as well: + '"no end quote': '"""no end quote"', + 'schema."table': '"schema"."""table"', + '"schema.table': '"""schema"."table"', + 'schema."table.something': '"schema"."""table"."something"', + + # Embedded dots + '"schema.test"."table.test"': '"schema.test"."table.test"', + '"schema.".table': '"schema."."table"', + '"schema."."table"': '"schema."."table"', + 'schema.".table"': '"schema".".table"', + '"schema".".table"': '"schema".".table"', + '"schema.".".table"': '"schema.".".table"', + # These are valid but maybe not what the user intended + '."table"': '".""table"""', + 'table.': '"table."', + } + + invalid = { + ('test.too.many.dots', 'table'): 'PostgreSQL does not support table with more than 3 dots', + ('"test.too".many.dots', 'database'): 'PostgreSQL does not support database with more than 1 dots', + ('test.too."many.dots"', 'database'): 'PostgreSQL does not support database with more than 1 dots', + ('"test"."too"."many"."dots"', 'database'): "PostgreSQL does not support database with more than 1 dots", + ('"test"."too"."many"."dots"', 'schema'): "PostgreSQL does not support schema with more than 2 dots", + ('"test"."too"."many"."dots"', 'table'): "PostgreSQL does not support table with more than 3 dots", + ('"test"."too"."many"."dots"."for"."column"', 'column'): "PostgreSQL does not support column with more than 4 dots", + ('"table "invalid" double quote"', 'table'): 'User escaped identifiers must escape extra double quotes', + ('"schema "invalid"""."table "invalid"', 'table'): 'User escaped identifiers must escape extra double quotes', + ('"schema."table"','table'): 'User escaped identifiers must escape extra double quotes', + ('"schema".', 'table'): 'Identifier name unspecified or unquoted trailing dot', + } + + def check_valid_quotes(self, identifier, quoted_identifier): + tools.eq_(pg_quote_identifier(identifier, 'table'), quoted_identifier) + + def test_valid_quotes(self): + for identifier in self.valid: + yield self.check_valid_quotes, identifier, self.valid[identifier] + + def check_invalid_quotes(self, identifier, id_type, msg): + if hasattr(tools, 'assert_raises_regexp'): + tools.assert_raises_regexp(SQLParseError, msg, pg_quote_identifier, *(identifier, id_type)) + else: + tools.assert_raises(SQLParseError, pg_quote_identifier, *(identifier, id_type)) + + def test_invalid_quotes(self): + for test in self.invalid: + yield self.check_invalid_quotes, test[0], test[1], self.invalid[test] + + def test_how_many_dots(self): + tools.eq_(pg_quote_identifier('role', 'role'), '"role"') + tools.assert_raises_regexp(SQLParseError, "PostgreSQL does not support role with more than 1 dots", pg_quote_identifier, *('role.more', 'role')) + + tools.eq_(pg_quote_identifier('db', 'database'), '"db"') + tools.assert_raises_regexp(SQLParseError, "PostgreSQL does not support database with more than 1 dots", pg_quote_identifier, *('db.more', 'database')) + + tools.eq_(pg_quote_identifier('db.schema', 'schema'), '"db"."schema"') + tools.assert_raises_regexp(SQLParseError, "PostgreSQL does not support schema with more than 2 dots", pg_quote_identifier, *('db.schema.more', 'schema')) + + tools.eq_(pg_quote_identifier('db.schema.table', 'table'), '"db"."schema"."table"') + tools.assert_raises_regexp(SQLParseError, "PostgreSQL does not support table with more than 3 dots", pg_quote_identifier, *('db.schema.table.more', 'table')) + + tools.eq_(pg_quote_identifier('db.schema.table.column', 'column'), '"db"."schema"."table"."column"') + tools.assert_raises_regexp(SQLParseError, "PostgreSQL does not support column with more than 4 dots", pg_quote_identifier, *('db.schema.table.column.more', 'column'))