NeMo/nemo/core/classes/modelPT.py
Oleksii Kuchaiev d1abb897c0
Docs, Contributing, Readme draft (#977)
* Docs, Contributing, Readme draft

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* minor readme update

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* update

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* docs non-existing link

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* update link

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* add readthedocs.yaml

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
2020-08-03 17:18:44 -07:00

668 lines
28 KiB
Python

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import os
import shutil
import tarfile
import tempfile
from abc import abstractmethod
from os import path
from typing import Dict, List, Optional, Union
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import LightningModule, Trainer
from nemo.core import optim
from nemo.core.classes.common import Model
from nemo.core.optim import prepare_lr_scheduler
from nemo.utils import logging, model_utils
# Public API of this module.
__all__ = ['ModelPT']
# Archive member name of the model configuration inside a .nemo file (written by ModelPT.save_to,
# read by ModelPT.restore_from).
_MODEL_CONFIG_YAML = "model_config.yaml"
# Archive member name of the serialized model state_dict inside a .nemo file.
_MODEL_WEIGHTS = "model_weights.ckpt"
class ModelPT(LightningModule, Model):
    """
    Interface for Pytorch-lightning based NeMo models
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Base class from which all NeMo models should inherit.

        Args:
            cfg (DictConfig): configuration object.
                The cfg object should have (optionally) the following sub-configs:

                * train_ds - to instantiate training dataset
                * validation_ds - to instantiate validation dataset
                * test_ds - to instantiate testing dataset
                * optim - to instantiate optimizer with learning rate scheduler

            trainer (Optional): Pytorch Lightning Trainer instance

        Raises:
            ValueError: if ``cfg`` is not a DictConfig, or ``trainer`` is neither None nor a Trainer.
        """
        if not isinstance(cfg, DictConfig):
            raise ValueError(f"cfg constructor argument must be of type DictConfig but got {type(cfg)} instead.")
        if trainer is not None and not isinstance(trainer, Trainer):
            raise ValueError(
                f"trainer constructor argument must be either None or pytorch_lightning.Trainer. "
                f"But got {type(trainer)} instead."
            )
        super().__init__()

        if 'target' not in cfg:
            # This is for Jarvis service: record the fully-qualified class path of this model in the config.
            # The struct flag is lifted temporarily because struct mode forbids adding new keys.
            OmegaConf.set_struct(cfg, False)
            cfg.target = "{0}.{1}".format(self.__class__.__module__, self.__class__.__name__)
            OmegaConf.set_struct(cfg, True)

        self._cfg = cfg
        self.save_hyperparameters(self._cfg)

        self._train_dl = None
        self._validation_dl = None
        self._test_dl = None
        self._optimizer = None
        self._scheduler = None
        self._trainer = trainer

        # Multi-dataloader bookkeeping used by `validation_epoch_end` / `test_epoch_end`. The defaults are set
        # unconditionally so those methods never hit a missing attribute; the model_utils resolvers below
        # may override them.
        self._validation_loss_idx = 0
        self._validation_names = None
        self._test_loss_idx = 0
        self._test_names = None

        if 'train_ds' in self._cfg and self._cfg.train_ds is not None:
            self.setup_training_data(self._cfg.train_ds)

        if 'validation_ds' in self._cfg and self._cfg.validation_ds is not None:
            model_utils.resolve_validation_dataloaders(model=self)
            # Fall back to positional names when several loaders exist but the resolver assigned no names.
            if self._validation_names is None and isinstance(self._validation_dl, (list, tuple)):
                self._validation_names = ['val_{}_'.format(idx) for idx in range(len(self._validation_dl))]

        if 'test_ds' in self._cfg and self._cfg.test_ds is not None:
            model_utils.resolve_test_dataloaders(model=self)
            if self._test_names is None and isinstance(self._test_dl, (list, tuple)):
                self._test_names = ['test_{}_'.format(idx) for idx in range(len(self._test_dl))]

    def register_artifact(self, config_path: str, src: str):
        """
        Register model artifacts with this function. These artifacts (files) will be included inside .nemo file
        when model.save_to("mymodel.nemo") is called.

        WARNING: If you specified /example_folder/example.txt but ./example.txt exists, then ./example.txt will be used.

        Args:
            config_path: config path where artifact is used
            src: path to the artifact

        Returns:
            path to be used when accessing artifact. If src='' (or whitespace) or None then src is returned unchanged
            and nothing is registered.

        Raises:
            FileNotFoundError: if the resolved artifact path does not exist.
        """
        if getattr(self, 'artifacts', None) is None:
            self.artifacts = []

        if src is None or src.strip() == '':
            return src

        basename_src = os.path.basename(src)
        # filename exists in current workdir - use it and raise warning
        if os.path.exists(basename_src):
            logging.warning(f"Using {os.path.abspath(basename_src)} instead of {src}.")
            used_src = basename_src
        else:
            used_src = src
        if not os.path.exists(used_src):
            raise FileNotFoundError(f"Could not find {used_src}")
        self.artifacts.append((config_path, used_src))
        return used_src

    def save_to(self, save_path: str):
        """
        Saves model instance (weights and configuration) into .nemo file. You can use "restore_from" method to fully
        restore instance from .nemo file.

        .nemo file is an archive (tar.gz) with the following:
            model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for
                model's constructor
            model_weights.ckpt - model checkpoint (the model's state_dict)
            registered artifacts - copies of files registered via ``register_artifact`` (copied best-effort)

        Args:
            save_path: Path to .nemo file where model instance should be saved
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            config_yaml = path.join(tmpdir, _MODEL_CONFIG_YAML)
            model_weights = path.join(tmpdir, _MODEL_WEIGHTS)
            for conf_path, src in getattr(self, 'artifacts', None) or []:
                # Best-effort: a failed artifact copy is logged but does not abort saving the model.
                try:
                    if os.path.exists(src):
                        shutil.copy2(src, tmpdir)
                except Exception:
                    logging.error(f"Could not copy artifact {src} used in {conf_path}")
            self.to_config_file(path2yaml_file=config_yaml)
            torch.save(self.state_dict(), model_weights)
            self.__make_nemo_file_from_folder(filename=save_path, source_dir=tmpdir)

    @classmethod
    def restore_from(cls, restore_path: str):
        """
        Restores model instance (weights and configuration) from a .nemo file.

        Args:
            restore_path: path to .nemo file from which model should be instantiated

        Example:
            ```
            model = nemo.collections.asr.models.EncDecCTCModel.restore_from('asr.nemo')
            assert isinstance(model, nemo.collections.asr.models.EncDecCTCModel)
            ```

        Returns:
            An instance of type cls

        Raises:
            FileNotFoundError: if ``restore_path`` does not exist.
        """
        if not path.exists(restore_path):
            raise FileNotFoundError(f"Can't find {restore_path}")
        with tempfile.TemporaryDirectory() as tmpdir:
            cls.__unpack_nemo_file(path2file=restore_path, out_folder=tmpdir)
            config_yaml = path.join(tmpdir, _MODEL_CONFIG_YAML)
            model_weights = path.join(tmpdir, _MODEL_WEIGHTS)
            conf = OmegaConf.load(config_yaml)
            OmegaConf.set_struct(conf, True)
            instance = cls.from_config_dict(config=conf)
            # NOTE(security): torch.load unpickles the checkpoint, which can execute arbitrary code.
            # Only restore .nemo files from trusted sources.
            # map_location='cpu' allows checkpoints saved from GPU tensors to be restored on CPU-only hosts;
            # load_state_dict copies the loaded values into the instance's parameters.
            instance.load_state_dict(torch.load(model_weights, map_location='cpu'))
        return instance

    @abstractmethod
    def setup_training_data(self, train_data_config: Union[DictConfig, Dict]):
        """
        Setups data loader to be used in training

        Args:
            train_data_config: training data layer parameters.
        Returns:

        """
        pass

    @abstractmethod
    def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]):
        """
        (Optionally) Setups data loader to be used in validation

        Args:
            val_data_config: validation data layer parameters.
        Returns:

        """
        pass

    def setup_test_data(self, test_data_config: Union[DictConfig, Dict]):
        """
        (Optionally) Setups data loader to be used in test

        Args:
            test_data_config: test data layer parameters.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError()

    def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = None):
        """
        Prepares an optimizer (and optionally a learning rate scheduler) from a config.

        Args:
            optim_config: Optional optimizer config. If None, ``self._cfg.optim`` is used when present; if neither
                exists, no optimizer is created. Recognized keys:

                * "lr": mandatory learning rate. A ValueError is raised if it is missing.
                * "name": string name of an optimizer in the registry (``optim.get_optimizer``).
                    Defaults to "adam". Ignored when "cls" is given.
                * "cls": optimizer class, or its import path as a string, instantiated directly.
                * "args": optional optimizer arguments, parsed by ``optim.parse_optimizer_args``.
                    If absent, every remaining key (other than "name", "cls", "lr" and "sched") is
                    forwarded to the optimizer constructor as a keyword argument.
                * "sched": optional scheduler config, passed to ``prepare_lr_scheduler``.

                The passed-in config object is not modified.

        Returns:
            Tuple ``(optimizer, scheduler)``, or None if no optimizer config was found.

        Raises:
            ValueError: if "lr" is missing, or if a scheduler is configured while the trainer's
                ``accumulate_grad_batches`` is not an integer.
        """
        # If config was not explicitly passed to us, see if internal config has `optim` namespace
        if optim_config is None and self._cfg is not None and hasattr(self._cfg, 'optim'):
            optim_config = self._cfg.optim

        # If config is still None, or internal config has no Optim, return without instantiation
        if optim_config is None:
            logging.info('No optimizer config provided, therefore no optimizer was created')
            return

        # Work on a private copy so the caller's config (including its nested `sched` dict) is never mutated.
        if isinstance(optim_config, DictConfig):
            optim_config = OmegaConf.to_container(optim_config)
        else:
            optim_config = copy.deepcopy(optim_config)

        # Inject trainer-derived scheduler settings
        if 'sched' in optim_config and self._trainer is not None:
            if not isinstance(self._trainer.accumulate_grad_batches, int):
                raise ValueError("We do not currently support gradient accumulation that is not an integer.")
            if self._trainer.max_steps is None:
                num_gpus = self._trainer.num_gpus
                # Training on CPU (num_gpus == 0) counts as a single device.
                num_devices = num_gpus if num_gpus != 0 else 1
                iters_per_batch = self._trainer.max_epochs / float(
                    num_devices * self._trainer.num_nodes * self._trainer.accumulate_grad_batches
                )
                optim_config['sched']['iters_per_batch'] = iters_per_batch
            else:
                optim_config['sched']['max_steps'] = self._trainer.max_steps

        # Force into DictConfig from nested structure, then get back a mutable nested dict with
        # interpolations resolved.
        optim_config = OmegaConf.to_container(OmegaConf.create(optim_config), resolve=True)

        # Extract scheduler config if inside optimizer config
        scheduler_config = optim_config.pop('sched', None)

        # Check if caller provided an optimizer class; otherwise resolve by name, defaulting to Adam
        optimizer_cls = optim_config.get('cls', None)
        if optimizer_cls is None:
            optimizer_name = optim_config.get('name', 'adam')
        elif inspect.isclass(optimizer_cls):
            optimizer_name = optimizer_cls.__name__.lower()
        else:
            # resolve the class name (lowercase) from the class path
            optimizer_name = optimizer_cls.split(".")[-1].lower()

        lr = optim_config.get('lr', None)
        if lr is None:
            raise ValueError('`lr` must be passed to `optimizer_config` when setting up the optimization !')

        # Optimizer kwargs come either from the explicit `args` entry or from all remaining config keys
        if 'args' in optim_config:
            optimizer_args = optim.parse_optimizer_args(optimizer_name, optim_config.pop('args'))
        else:
            optimizer_args = copy.deepcopy(optim_config)

        # Remove keys that are consumed here rather than passed to the optimizer constructor
        for consumed_key in ('name', 'cls', 'lr'):
            optimizer_args.pop(consumed_key, None)

        # Actually instantiate the optimizer
        if optimizer_cls is None:
            optimizer = optim.get_optimizer(optimizer_name)(self.parameters(), lr=lr, **optimizer_args)
        elif inspect.isclass(optimizer_cls):
            optimizer = optimizer_cls(self.parameters(), lr=lr, **optimizer_args)
        else:
            # Attempt class path resolution through hydra
            optimizer_config = {'lr': lr}
            optimizer_config.update(optimizer_args)
            try:
                optimizer = hydra.utils.instantiate(
                    OmegaConf.create({'cls': optimizer_cls}), self.parameters(), **optimizer_config
                )
            except Exception:
                logging.error(
                    "Could not instantiate class path - {} with kwargs {}".format(optimizer_cls, str(optimizer_config))
                )
                raise
        logging.info("Optimizer config = %s", str(optimizer))
        self._optimizer = optimizer

        # Try to instantiate scheduler for optimizer
        self._scheduler = prepare_lr_scheduler(
            optimizer=self._optimizer, scheduler_config=scheduler_config, train_dataloader=self._train_dl
        )

        # Return the optimizer with/without scheduler
        # This return allows multiple optimizers or schedulers to be created
        return self._optimizer, self._scheduler

    def configure_optimizers(self):
        """
        Lightning hook: builds the optimizer (and scheduler) via ``setup_optimization``.

        Returns:
            The optimizer alone when no scheduler was created, otherwise ``([optimizer], [scheduler])``.
        """
        self.setup_optimization()
        if self._scheduler is None:
            return self._optimizer
        return [self._optimizer], [self._scheduler]

    def train_dataloader(self):
        """Returns the training data loader set up by ``setup_training_data`` (None if not set up)."""
        return self._train_dl

    def training_step(self, batch, batch_ix):
        pass

    def val_dataloader(self):
        """Returns the validation data loader(s) (None if not set up)."""
        return self._validation_dl

    def validation_step(self, batch, batch_ix):
        pass

    def test_dataloader(self):
        """Returns the test data loader(s) (None if not set up)."""
        return self._test_dl

    def test_step(self, batch, batch_ix):
        pass

    def _merge_multi_dataloader_outputs(self, outputs, loss_key: str, loss_idx: int, get_prefix, epoch_end_fn):
        """
        Merges per-dataloader epoch-end results into a single result dict.

        For each data loader, ``epoch_end_fn(loader_outputs, dataloader_idx=idx)`` is called. Keys of its result
        are prefixed with ``get_prefix(idx)``, and entries under ``'log'`` are merged into one shared log dict.
        The ``loss_key`` value of the loader at ``loss_idx`` is additionally kept unprefixed, both at result
        level and in the log. ``loss_key`` of any other loader is logged as ``<prefix> + 'loss'``.

        Args:
            outputs: list (one entry per data loader) of lists of step outputs.
            loss_key: name of the loss key, e.g. 'val_loss' or 'test_loss'.
            loss_idx: index of the data loader whose loss is reported unprefixed.
            get_prefix: callable mapping a dataloader index to its name prefix.
            epoch_end_fn: per-dataloader epoch-end callable (``multi_*_epoch_end``).

        Returns:
            Dict with the merged ``'log'`` dict plus prefixed result-level entries.
        """
        output_dict = {'log': {}}
        # The output is a list of list of dicts, outer list corresponds to dataloader idx
        for dataloader_idx, loader_outputs in enumerate(outputs):
            dataloader_prefix = get_prefix(dataloader_idx)
            # If result was not provided, treat it as an empty dict
            dataloader_logs = epoch_end_fn(loader_outputs, dataloader_idx=dataloader_idx) or {}

            # Perform loss resolution first (if provided outside logs)
            if loss_key in dataloader_logs and loss_key not in output_dict and dataloader_idx == loss_idx:
                output_dict[loss_key] = dataloader_logs[loss_key]

            for key, value in dataloader_logs.items():
                if key != 'log':
                    # Values stored outside 'log' are simply prefixed and stored
                    output_dict[dataloader_prefix + key] = value
                    continue

                # Attach the data loader prefix to every log entry to avoid name clashes between loaders
                log_dict = {}
                for log_key, log_value in value.items():
                    if log_key == loss_key and loss_key not in output_dict['log'] and dataloader_idx == loss_idx:
                        # Keep the selected loss unprefixed, plus a prefixed duplicate for easy comparison
                        new_log_key = loss_key
                        log_dict[dataloader_prefix + log_key] = log_value
                    elif log_key == loss_key and dataloader_idx != loss_idx:
                        # Use <prefix> + 'loss' for other loaders; <prefix> + loss_key would repeat the prefix word
                        new_log_key = dataloader_prefix + 'loss'
                    else:
                        new_log_key = dataloader_prefix + log_key
                    log_dict[new_log_key] = log_value
                output_dict['log'].update(log_dict)
        return output_dict

    def validation_epoch_end(
        self, outputs: Union[List[Dict[str, torch.Tensor]], List[List[Dict[str, torch.Tensor]]]]
    ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
        """
        Default validation epoch-end aggregation, which automatically supports multiple data loaders
        via `multi_validation_epoch_end`.

        If multi dataset support is not required, override this method entirely in base class.
        In such a case, there is no need to implement `multi_validation_epoch_end` either.

        Note:
            If more than one data loader exists, and they all provide `val_loss`,
            only the `val_loss` of the first data loader will be used by default.
            This default can be changed by passing the special key `val_loss_idx: int`
            inside the `validation_ds` config.

        Args:
            outputs: Single or nested list of tensor outputs from one or more data loaders.

        Returns:
            None if there are no outputs. Otherwise a dictionary containing the union of all items from
            individual data_loaders, along with merged logs from all data loaders.
        """
        # Case where we dont provide data loaders (or there were no outputs)
        if not outputs:
            return None

        # Case where we provide exactly 1 data loader: a flat list of dicts
        if isinstance(outputs[0], dict):
            return self.multi_validation_epoch_end(outputs, dataloader_idx=0)

        # Case where we provide more than 1 data loader
        return self._merge_multi_dataloader_outputs(
            outputs,
            loss_key='val_loss',
            loss_idx=self._validation_loss_idx,
            get_prefix=self.get_validation_dataloader_prefix,
            epoch_end_fn=self.multi_validation_epoch_end,
        )

    def test_epoch_end(
        self, outputs: Union[List[Dict[str, torch.Tensor]], List[List[Dict[str, torch.Tensor]]]]
    ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
        """
        Default test epoch-end aggregation, which automatically supports multiple data loaders
        via `multi_test_epoch_end`.

        If multi dataset support is not required, override this method entirely in base class.
        In such a case, there is no need to implement `multi_test_epoch_end` either.

        Note:
            If more than one data loader exists, and they all provide `test_loss`,
            only the `test_loss` of the first data loader will be used by default.
            This default can be changed by passing the special key `test_loss_idx: int`
            inside the `test_ds` config.

        Args:
            outputs: Single or nested list of tensor outputs from one or more data loaders.

        Returns:
            None if there are no outputs. Otherwise a dictionary containing the union of all items from
            individual data_loaders, along with merged logs from all data loaders.
        """
        # Case where we dont provide data loaders (or there were no outputs)
        if not outputs:
            return None

        # Case where we provide exactly 1 data loader: a flat list of dicts
        if isinstance(outputs[0], dict):
            return self.multi_test_epoch_end(outputs, dataloader_idx=0)

        # Case where we provide more than 1 data loader. The loss index is the *test* one
        # (previously the validation index was used here by mistake).
        return self._merge_multi_dataloader_outputs(
            outputs,
            loss_key='test_loss',
            loss_idx=self._test_loss_idx,
            get_prefix=self.get_test_dataloader_prefix,
            epoch_end_fn=self.multi_test_epoch_end,
        )

    def multi_validation_epoch_end(
        self, outputs: List[Dict[str, torch.Tensor]], dataloader_idx: int = 0
    ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
        """Per-dataloader validation epoch-end hook. Subclasses override it; this default only logs a warning."""
        logging.warning(
            "Multi data loader support has been enabled, but "
            "`multi_validation_epoch_end(outputs, dataloader_idx)` has not been implemented.\n"
            "If you require multi data loader support for validation sets, please override this method.\n"
            "If you do not require multi data loader support, please instead override "
            "`validation_epoch_end(outputs)`."
        )

    def multi_test_epoch_end(
        self, outputs: List[Dict[str, torch.Tensor]], dataloader_idx: int = 0
    ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
        """Per-dataloader test epoch-end hook. Subclasses override it; this default only logs a warning."""
        logging.warning(
            "Multi data loader support has been enabled, but "
            "`multi_test_epoch_end(outputs, dataloader_idx)` has not been implemented.\n"
            "If you require multi data loader support for test sets, please override this method.\n"
            "If you do not require multi data loader support, please instead override "
            "`test_epoch_end(outputs)`."
        )

    def get_validation_dataloader_prefix(self, dataloader_idx=0):
        """Returns the name prefix of the validation data loader at ``dataloader_idx``."""
        return self._validation_names[dataloader_idx]

    def get_test_dataloader_prefix(self, dataloader_idx=0):
        """Returns the name prefix of the test data loader at ``dataloader_idx``."""
        return self._test_names[dataloader_idx]

    def prepare_test(self, trainer: 'Trainer') -> bool:
        """
        Helper method to check whether the model can safely be tested
        on a dataset after training (or loading a checkpoint).

        # Usage:
        trainer = Trainer()
        if model.prepare_test(trainer):
            trainer.test(model)

        Returns:
            bool which declares the model safe to test. Provides warnings if it has to
            return False to guide the user.
        """
        if not hasattr(self._cfg, 'test_ds'):
            logging.info("No `test_ds` config found within the manifest.")
            return False

        # Replace ddp multi-gpu until PTL has a fix
        DDP_WARN = (
            "\n\nDuring testing, it is currently advisable to construct a new Trainer "
            "with single GPU and no DDP.\n"
            "Following pattern should be used: \n"
            "trainer = Trainer()\n"
            "if model.prepare_test(trainer):\n"
            "    trainer.test(model)\n\n"
        )

        if trainer is not None and trainer.num_gpus > 1:
            logging.warning(DDP_WARN)
            return False

        return True

    @property
    def num_weights(self):
        """Number of trainable (``requires_grad``) parameter elements in the model."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @staticmethod
    def __make_nemo_file_from_folder(filename, source_dir):
        """Writes the contents of ``source_dir`` into the gzip-compressed tar archive ``filename``."""
        with tarfile.open(filename, "w:gz") as tar:
            # arcname "./" stores members relative to the archive root rather than under the temp dir's name
            tar.add(source_dir, arcname="./")

    @staticmethod
    def __unpack_nemo_file(path2file: str, out_folder: str) -> str:
        """
        Extracts the .nemo (tar.gz) archive ``path2file`` into ``out_folder``.

        Returns:
            ``out_folder``

        Raises:
            FileNotFoundError: if ``path2file`` does not exist.
            ValueError: if an archive member would be written outside ``out_folder``.
        """
        if not path.exists(path2file):
            raise FileNotFoundError(f"{path2file} does not exist")
        out_root = path.realpath(out_folder)
        # `with` guarantees the archive handle is closed even if extraction fails.
        with tarfile.open(path2file, "r:gz") as tar:
            # A .nemo file may come from an untrusted source: refuse members with absolute paths or '..'
            # components that would escape out_folder (tarfile may not guard against this on its own).
            for member in tar.getmembers():
                target = path.realpath(path.join(out_root, member.name))
                if target != out_root and not target.startswith(out_root + os.sep):
                    raise ValueError(f"Refusing to extract {member.name}: it would be written outside {out_folder}")
            tar.extractall(path=out_folder)
        return out_folder