Cleanup some minor issues in PluginLoader:

* class_only was a keyword arg of get() and all() that was mistakenly
  passed on to Plugins.  Be sure to strip it from the keyword args
  before instantiating Plugins.  (Reworked API probably should either
  not instantiate Plugins or take the args for the Plugin as a separate
  list and a dict.)
* Checking required base_classes was only done in get() and only if
  class_only was False (ie: that Plugin was instantiated).  This meant
  that different plugins could be found depending on whether the call
  was to .get() or to all() and whether it was for classes or instances.
  Fixed so that required base_classes are always checked.
This commit is contained in:
Toshio Kuratomi 2016-03-21 08:39:57 -07:00
parent 407f8f934e
commit 7ce130212f

View file

@ -316,6 +316,7 @@ class PluginLoader:
def get(self, name, *args, **kwargs): def get(self, name, *args, **kwargs):
''' instantiates a plugin of the given name using arguments ''' ''' instantiates a plugin of the given name using arguments '''
class_only = kwargs.pop('class_only', False)
if name in self.aliases: if name in self.aliases:
name = self.aliases[name] name = self.aliases[name]
path = self.find_plugin(name) path = self.find_plugin(name)
@ -325,23 +326,28 @@ class PluginLoader:
if path not in self._module_cache: if path not in self._module_cache:
self._module_cache[path] = self._load_module_source('.'.join([self.package, name]), path) self._module_cache[path] = self._load_module_source('.'.join([self.package, name]), path)
if kwargs.get('class_only', False):
obj = getattr(self._module_cache[path], self.class_name) obj = getattr(self._module_cache[path], self.class_name)
else:
obj = getattr(self._module_cache[path], self.class_name)(*args, **kwargs)
if self.base_class: if self.base_class:
# The import path is hardcoded and should be the right place, # The import path is hardcoded and should be the right place,
# so we are not expecting an ImportError. # so we are not expecting an ImportError.
module = __import__(self.package, fromlist=[self.base_class]) module = __import__(self.package, fromlist=[self.base_class])
# Check whether this obj has the required base class. # Check whether this obj has the required base class.
if not issubclass(obj.__class__, getattr(module, self.base_class, None)): try:
plugin_class = getattr(module, self.base_class)
except AttributeError:
return None return None
if not issubclass(obj, plugin_class):
return None
if not class_only:
obj = obj(*args, **kwargs)
return obj return obj
def all(self, *args, **kwargs): def all(self, *args, **kwargs):
''' instantiates all plugins with the same arguments ''' ''' instantiates all plugins with the same arguments '''
class_only = kwargs.pop('class_only', False)
for i in self._get_paths(): for i in self._get_paths():
matches = glob.glob(os.path.join(i, "*.py")) matches = glob.glob(os.path.join(i, "*.py"))
matches.sort() matches.sort()
@ -353,13 +359,21 @@ class PluginLoader:
if path not in self._module_cache: if path not in self._module_cache:
self._module_cache[path] = self._load_module_source(name, path) self._module_cache[path] = self._load_module_source(name, path)
if kwargs.get('class_only', False):
obj = getattr(self._module_cache[path], self.class_name) obj = getattr(self._module_cache[path], self.class_name)
else: if self.base_class:
obj = getattr(self._module_cache[path], self.class_name)(*args, **kwargs) # The import path is hardcoded and should be the right place,
# so we are not expecting an ImportError.
if self.base_class and self.base_class not in [base.__name__ for base in obj.__class__.__bases__]: module = __import__(self.package, fromlist=[self.base_class])
# Check whether this obj has the required base class.
try:
plugin_class = getattr(module, self.base_class)
except AttributeError:
continue continue
if not issubclass(obj, plugin_class):
continue
if not class_only:
obj = obj(*args, **kwargs)
# set extra info on the module, in case we want it later # set extra info on the module, in case we want it later
setattr(obj, '_original_path', path) setattr(obj, '_original_path', path)