New get_all_subclasses function in basic modules and use it in fact modules.

This commit is contained in:
Yannig Perre 2016-04-23 09:14:20 +02:00 committed by Toshio Kuratomi
parent eb18767f91
commit 72f17f3ff3
2 changed files with 28 additions and 21 deletions

View file

@ -303,6 +303,26 @@ def get_distribution_version():
distribution_version = None
return distribution_version
def get_all_subclasses(cls):
'''
used by modules like Hardware or Network fact classes to retrieve all subclasses of a given class.
__subclasses__ return only direct sub classes. This one go down into the class tree.
'''
# Retrieve direct subclasses
subclasses = cls.__subclasses__()
to_visit = list(subclasses)
# Then visit all subclasses
while to_visit:
for sc in to_visit:
# The current class is now visited, so remove it from list
to_visit.remove(sc)
# Appending all subclasses to visit and keep a reference of available class
for ssc in sc.__subclasses__():
subclasses.append(ssc)
to_visit.append(ssc)
return subclasses
def load_platform_subclass(cls, *args, **kwargs):
'''
used by modules like User to have different implementations based on detected platform. See User
@ -315,11 +335,11 @@ def load_platform_subclass(cls, *args, **kwargs):
# get the most specific superclass for this platform
if distribution is not None:
for sc in cls.__subclasses__():
for sc in get_all_subclasses(cls):
if sc.distribution is not None and sc.distribution == distribution and sc.platform == this_platform:
subclass = sc
if subclass is None:
for sc in cls.__subclasses__():
for sc in get_all_subclasses(cls):
if sc.platform == this_platform and sc.distribution is None:
subclass = sc
if subclass is None:

View file

@ -32,6 +32,7 @@ import datetime
import getpass
import pwd
import ConfigParser
from basic import get_all_subclasses
# py2 vs py3; replace with six via ziploader
try:
@ -867,7 +868,7 @@ class Hardware(Facts):
def __new__(cls, *arguments, **keyword):
subclass = cls
for sc in Hardware.__subclasses__():
for sc in get_all_subclasses(Hardware):
if sc.platform == platform.system():
subclass = sc
return super(cls, subclass).__new__(subclass, *arguments, **keyword)
@ -1949,23 +1950,9 @@ class Network(Facts):
def __new__(cls, *arguments, **keyword):
subclass = cls
# Retrieve direct subclasses
to_visit = Network.__subclasses__()
# Then visit all subclasses
while to_visit:
for sc in to_visit:
# Check if current class is the good one
if sc.platform == platform.system():
subclass = sc
to_visit = []
break
# The current class is now visited, so remove it from list
to_visit.remove(sc)
# Appending all subclasses to visit and keep a reference of available class
for ssc in sc.__subclasses__():
to_visit.append(ssc)
# Now, return corresponding subclass
for sc in get_all_subclasses(Network):
if sc.platform == platform.system():
subclass = sc
return super(cls, subclass).__new__(subclass, *arguments, **keyword)
def populate(self):
@ -2725,7 +2712,7 @@ class Virtual(Facts):
def __new__(cls, *arguments, **keyword):
subclass = cls
for sc in Virtual.__subclasses__():
for sc in get_all_subclasses(Virtual):
if sc.platform == platform.system():
subclass = sc
return super(cls, subclass).__new__(subclass, *arguments, **keyword)