Support target classpath resolution for all ModelPT subclasses (#1982)
* Support target classpath resolution for all ModelPT subclasses

  Signed-off-by: smajumdar <titu1994@gmail.com>

* Support target classpath resolution for all ModelPT subclasses

  Signed-off-by: smajumdar <titu1994@gmail.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
This commit is contained in:
parent
d10f188067
commit
067c1f2a04
|
@ -15,6 +15,7 @@
|
|||
|
||||
"""Interfaces common to all Neural Modules and Models."""
|
||||
import hashlib
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
|
@ -31,7 +32,7 @@ import nemo
|
|||
from nemo.core.neural_types import NeuralType, NeuralTypeComparisonResult
|
||||
from nemo.utils import logging
|
||||
from nemo.utils.cloud import maybe_download_from_cloud
|
||||
from nemo.utils.model_utils import maybe_update_config_version
|
||||
from nemo.utils.model_utils import import_class_by_path, maybe_update_config_version
|
||||
|
||||
__all__ = ['Typing', 'FileIO', 'Model', 'Serialization', 'typecheck']
|
||||
|
||||
|
@ -425,15 +426,43 @@ class Serialization(ABC):
|
|||
|
||||
config = maybe_update_config_version(config)
|
||||
|
||||
# Hydra 0.x API
|
||||
if ('cls' in config or 'target' in config) and 'params' in config:
|
||||
# regular hydra-based instantiation
|
||||
instance = hydra.utils.instantiate(config=config)
|
||||
# Hydra 1.x API
|
||||
elif '_target_' in config:
|
||||
# regular hydra-based instantiation
|
||||
instance = hydra.utils.instantiate(config=config)
|
||||
else:
|
||||
# models are handled differently for now
|
||||
instance = cls(cfg=config)
|
||||
instance = None
|
||||
|
||||
# Attempt class path resolution from config `target` class (if it exists)
|
||||
if 'target' in config:
|
||||
target_cls = config.target
|
||||
imported_cls = None
|
||||
try:
|
||||
# try to import the target class
|
||||
imported_cls = import_class_by_path(target_cls)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
logging.debug(f'Target class `{target_cls}` could not be imported, falling back to original cls')
|
||||
|
||||
# try instantiating model with target class
|
||||
if imported_cls is not None:
|
||||
try:
|
||||
instance = imported_cls(cfg=config)
|
||||
except Exception:
|
||||
imported_cls_tb = traceback.format_exc()
|
||||
logging.debug(
|
||||
f"Model instantiation from target class failed with following error.\n"
|
||||
f"Falling back to `cls`.\n"
|
||||
f"{imported_cls_tb}"
|
||||
)
|
||||
instance = None
|
||||
|
||||
# target class resolution was unsuccessful, fall back to current `cls`
|
||||
if instance is None:
|
||||
instance = cls(cfg=config)
|
||||
|
||||
if not hasattr(instance, '_cfg'):
|
||||
instance._cfg = config
|
||||
|
|
|
@ -133,7 +133,7 @@ class TestSaveRestore:
|
|||
self.__test_restore_elsewhere(model=cn, attr_for_eq_check=set(["decoder._feat_in", "decoder._num_classes"]))
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_EncDecCTCModelBPE(self):
|
||||
def test_EncDecCTCModelBPE_v2(self):
|
||||
# TODO: Switch to using named configs because here we don't really care about weights
|
||||
cn = EncDecCTCModelBPE.from_pretrained(model_name="stt_en_conformer_ctc_small")
|
||||
self.__test_restore_elsewhere(model=cn, attr_for_eq_check=set(["decoder._feat_in", "decoder._num_classes"]))
|
||||
|
@ -271,3 +271,36 @@ class TestSaveRestore:
|
|||
|
||||
# Test that new config has arbitrary content
|
||||
assert model_copy.cfg.xyz == "abc"
|
||||
|
||||
@pytest.mark.unit
def test_mock_save_to_restore_from_with_target_class(self):
    """Save a ``MockModel`` and restore it through ``ModelPT.restore_from``.

    The checkpoint is written by ``MockModel.save_to`` and then loaded via the
    ``ModelPT`` entry point. The test asserts that the restored object is a
    ``MockModel``, that its ``w.weight`` matches the original, that its
    ``temp_file`` equals the basename of the original temp file path, and that
    its ``temp_data`` holds the line written to that file.
    """
    with tempfile.NamedTemporaryFile('w') as empty_file:
        # Write some data and flush so it is on disk before the model reads it
        empty_file.writelines(["*****\n"])
        empty_file.flush()

        # Point the mock config at the temporary file
        cfg = _mock_model_config()
        cfg.model.temp_file = empty_file.name

        # Create model
        model = MockModel(cfg=cfg.model, trainer=None)
        model = model.to('cpu')  # type: MockModel

        assert model.temp_file == empty_file.name

        # Save file using MockModel
        with tempfile.TemporaryDirectory() as save_folder:
            save_path = os.path.join(save_folder, "temp.nemo")
            model.save_to(save_path)

            # Restore test (using ModelPT as restorer)
            # This forces the target class = MockModel to be used as resolver
            model_copy = ModelPT.restore_from(save_path, map_location='cpu')

        # Check the resolved class first: if resolution fell back to another
        # class, the attribute checks below would fail with a less useful error.
        assert isinstance(model_copy, MockModel)

        # Compare weights by absolute difference. A signed mean would pass when
        # the restored weights are larger than the originals or when positive
        # and negative differences cancel out.
        # NOTE(review): assumes `w.weight` is a torch tensor supporting `.abs()` - confirm.
        diff = model.w.weight - model_copy.w.weight
        assert diff.abs().mean() <= 1e-9

        assert os.path.basename(model.temp_file) == model_copy.temp_file
        assert model_copy.temp_data == ["*****\n"]
|
||||
|
|
Loading…
Reference in a new issue