Normalize privs and flags to uppercase so comparisons against allowed names will work

This commit is contained in:
Toshio Kuratomi 2014-11-25 00:44:18 -08:00 committed by Matt Clay
parent e0ac340f59
commit 3e9771f544
2 changed files with 6 additions and 6 deletions

View file

@ -563,7 +563,7 @@ def main():
try: try:
# privs # privs
if p.privs: if p.privs:
privs = frozenset(p.privs.split(',')) privs = frozenset(pr.upper() for pr in p.privs.split(','))
if not privs.issubset(VALID_PRIVS): if not privs.issubset(VALID_PRIVS):
module.fail_json(msg='Invalid privileges specified: %s' % privs.difference(VALID_PRIVS)) module.fail_json(msg='Invalid privileges specified: %s' % privs.difference(VALID_PRIVS))
else: else:

View file

@ -155,7 +155,7 @@ else:
postgresqldb_found = True postgresqldb_found = True
_flags = ('SUPERUSER', 'CREATEROLE', 'CREATEUSER', 'CREATEDB', 'INHERIT', 'LOGIN', 'REPLICATION') _flags = ('SUPERUSER', 'CREATEROLE', 'CREATEUSER', 'CREATEDB', 'INHERIT', 'LOGIN', 'REPLICATION')
VALID_FLAGS = frozenset(itertools.chain(_flags, ('NO%s' %f for f in _flags))) VALID_FLAGS = frozenset(itertools.chain(_flags, ('NO%s' % f for f in _flags)))
VALID_PRIVS = dict(table=frozenset(('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', 'REFERENCES', 'TRIGGER', 'ALL')), VALID_PRIVS = dict(table=frozenset(('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', 'REFERENCES', 'TRIGGER', 'ALL')),
database=frozenset(('CREATE', 'CONNECT', 'TEMPORARY', 'TEMP', 'ALL')), database=frozenset(('CREATE', 'CONNECT', 'TEMPORARY', 'TEMP', 'ALL')),
@ -399,9 +399,9 @@ def parse_role_attrs(role_attr_flags):
""" """
if ',' in role_attr_flags: if ',' in role_attr_flags:
flag_set = frozenset(role_attr_flags.split(",")) flag_set = frozenset(r.upper() for r in role_attr_flags.split(","))
else: else:
flag_set = frozenset(role_attr_flags) flag_set = frozenset(role_attr_flags.upper())
if not flag_set.is_subset(VALID_FLAGS): if not flag_set.is_subset(VALID_FLAGS):
raise InvalidFlagsError('Invalid role_attr_flags specified: %s' % raise InvalidFlagsError('Invalid role_attr_flags specified: %s' %
' '.join(flag_set.difference(VALID_FLAGS))) ' '.join(flag_set.difference(VALID_FLAGS)))
@ -431,11 +431,11 @@ def parse_privs(privs, db):
if ':' not in token: if ':' not in token:
type_ = 'database' type_ = 'database'
name = db name = db
priv_set = frozenset(x.strip() for x in token.split(',')) priv_set = frozenset(x.strip().upper() for x in token.split(','))
else: else:
type_ = 'table' type_ = 'table'
name, privileges = token.split(':', 1) name, privileges = token.split(':', 1)
priv_set = frozenset(x.strip() for x in privileges.split(',')) priv_set = frozenset(x.strip().upper() for x in privileges.split(','))
if not priv_set.issubset(VALID_PRIVS[type_]): if not priv_set.issubset(VALID_PRIVS[type_]):
raise InvalidPrivsError('Invalid privs specified for %s: %s' % raise InvalidPrivsError('Invalid privs specified for %s: %s' %