Support target classpath resolution for all ModelPT subclasses (#1982)

* Support target classpath resolution for all ModelPT subclasses

Signed-off-by: smajumdar <titu1994@gmail.com>

* Support target classpath resolution for all ModelPT subclasses

Signed-off-by: smajumdar <titu1994@gmail.com>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
This commit is contained in:
Somshubra Majumdar 2021-03-29 19:51:15 -07:00 committed by GitHub
parent d10f188067
commit 067c1f2a04
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 4 deletions

View file

@ -15,6 +15,7 @@
"""Interfaces common to all Neural Modules and Models."""
import hashlib
import traceback
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
@ -31,7 +32,7 @@ import nemo
from nemo.core.neural_types import NeuralType, NeuralTypeComparisonResult
from nemo.utils import logging
from nemo.utils.cloud import maybe_download_from_cloud
from nemo.utils.model_utils import maybe_update_config_version
from nemo.utils.model_utils import import_class_by_path, maybe_update_config_version
__all__ = ['Typing', 'FileIO', 'Model', 'Serialization', 'typecheck']
@ -425,15 +426,43 @@ class Serialization(ABC):
config = maybe_update_config_version(config)
# Hydra 0.x API
if ('cls' in config or 'target' in config) and 'params' in config:
# regular hydra-based instantiation
instance = hydra.utils.instantiate(config=config)
# Hydra 1.x API
elif '_target_' in config:
# regular hydra-based instantiation
instance = hydra.utils.instantiate(config=config)
else:
# models are handled differently for now
instance = cls(cfg=config)
instance = None
# Attempt class path resolution from config `target` class (if it exists)
if 'target' in config:
target_cls = config.target
imported_cls = None
try:
# try to import the target class
imported_cls = import_class_by_path(target_cls)
except (ImportError, ModuleNotFoundError):
logging.debug(f'Target class `{target_cls}` could not be imported, falling back to original cls')
# try instantiating model with target class
if imported_cls is not None:
try:
instance = imported_cls(cfg=config)
except Exception:
imported_cls_tb = traceback.format_exc()
logging.debug(
f"Model instantiation from target class failed with following error.\n"
f"Falling back to `cls`.\n"
f"{imported_cls_tb}"
)
instance = None
# target class resolution was unsuccessful, fall back to current `cls`
if instance is None:
instance = cls(cfg=config)
if not hasattr(instance, '_cfg'):
instance._cfg = config

View file

@ -133,7 +133,7 @@ class TestSaveRestore:
self.__test_restore_elsewhere(model=cn, attr_for_eq_check=set(["decoder._feat_in", "decoder._num_classes"]))
@pytest.mark.unit
def test_EncDecCTCModelBPE(self):
def test_EncDecCTCModelBPE_v2(self):
# TODO: Switch to using named configs because here we don't really care about weights
cn = EncDecCTCModelBPE.from_pretrained(model_name="stt_en_conformer_ctc_small")
self.__test_restore_elsewhere(model=cn, attr_for_eq_check=set(["decoder._feat_in", "decoder._num_classes"]))
@ -271,3 +271,36 @@ class TestSaveRestore:
# Test that new config has arbitrary content
assert model_copy.cfg.xyz == "abc"
@pytest.mark.unit
def test_mock_save_to_restore_from_with_target_class(self):
with tempfile.NamedTemporaryFile('w') as empty_file:
# Write some data
empty_file.writelines(["*****\n"])
empty_file.flush()
# Update config
cfg = _mock_model_config()
cfg.model.temp_file = empty_file.name
# Create model
model = MockModel(cfg=cfg.model, trainer=None)
model = model.to('cpu') # type: MockModel
assert model.temp_file == empty_file.name
# Save file using MockModel
with tempfile.TemporaryDirectory() as save_folder:
save_path = os.path.join(save_folder, "temp.nemo")
model.save_to(save_path)
# Restore test (using ModelPT as restorer)
# This forces the target class = MockModel to be used as resolver
model_copy = ModelPT.restore_from(save_path, map_location='cpu')
# Restore test
diff = model.w.weight - model_copy.w.weight
assert diff.mean() <= 1e-9
assert isinstance(model_copy, MockModel)
assert os.path.basename(model.temp_file) == model_copy.temp_file
assert model_copy.temp_data == ["*****\n"]