Add a stateless timer to specify max_time per run instead of global max_time across runs (#3056)
* Add a stateless timer to specify max_time per run instead of global max_time across runs

  Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style fixes

  Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
parent e05712fea2
commit 0d0de9b367
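Background for the change: PyTorch Lightning's Timer callback (installed when trainer.max_time is set) saves its elapsed time into every checkpoint and restores it on resume, so max_time effectively acts as one global budget across all resumed runs. The StatelessTimer added here skips that persistence, so each submitted run gets the full max_time again, which matches the common cluster workflow of resubmitting a job with a fixed wall-clock limit. Below is an illustrative sketch of how the per-run budget might be expressed in the Hydra/OmegaConf config this script reads; only trainer.max_time is what the diff actually consumes, the rest is assumption.

# Illustrative config sketch (not part of the diff).
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "trainer": {
            # PTL's Timer accepts "DD:HH:MM:SS" (or a datetime.timedelta).
            # Example: 3 h 50 min per run, leaving headroom inside a 4-hour allocation.
            "max_time": "00:03:50:00",
        }
    }
)
print(cfg.trainer.max_time)  # -> 00:03:50:00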
@@ -16,6 +16,7 @@ from pathlib import Path
 from omegaconf.omegaconf import OmegaConf
 from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks.timer import Timer

 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
 from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank
@@ -28,7 +29,7 @@ from nemo.collections.nlp.parts.nlp_overrides import (
 )
 from nemo.core.config import hydra_runner
 from nemo.utils import logging
-from nemo.utils.exp_manager import exp_manager
+from nemo.utils.exp_manager import StatelessTimer, exp_manager


 @hydra_runner(config_path="conf", config_name="megatron_gpt_config")
@@ -69,6 +70,10 @@ def main(cfg) -> None:
     logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

     trainer.checkpoint_connector = NLPCheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
+    # Override timer callback to a stateless one
+    for idx, callback in enumerate(trainer.callbacks):
+        if isinstance(callback, Timer):
+            trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,)

     model = MegatronGPTModel(cfg.model, trainer)

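The loop above keeps whatever duration was configured via cfg.trainer.max_time but replaces the Timer instance that PTL installed for that setting with the new StatelessTimer. A minimal standalone sketch of the same pattern, assuming only pytorch_lightning and the StatelessTimer introduced by this commit (model and data setup omitted):

# Standalone usage sketch, not part of the diff.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.timer import Timer

from nemo.utils.exp_manager import StatelessTimer

max_time = "00:03:50:00"  # DD:HH:MM:SS
trainer = Trainer(max_time=max_time)  # PTL installs a Timer callback for max_time

# Swap the stock Timer for the stateless one so the budget resets on every (re)run.
for idx, callback in enumerate(trainer.callbacks):
    if isinstance(callback, Timer):
        trainer.callbacks[idx] = StatelessTimer(max_time)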
@@ -19,6 +19,7 @@ import sys
 import time
 from copy import deepcopy
 from dataclasses import dataclass
+from datetime import timedelta
 from pathlib import Path
 from shutil import copy, move
 from typing import Any, Dict, List, Optional, Union
@@ -28,6 +29,7 @@ from hydra.core.hydra_config import HydraConfig
 from hydra.utils import get_original_cwd
 from omegaconf import DictConfig, OmegaConf, open_dict
 from pytorch_lightning.callbacks import Callback, ModelCheckpoint
+from pytorch_lightning.callbacks.timer import Interval, Timer
 from pytorch_lightning.loggers import LoggerCollection as _LoggerCollection
 from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
@@ -919,3 +921,16 @@ def check_slurm(trainer):
         return trainer.accelerator_connector.is_slurm_managing_tasks
     except AttributeError:
         return False
+
+
+class StatelessTimer(Timer):
+    """Extension of PTL timers to be per run."""
+
+    def __init__(self, duration: timedelta = None, interval: str = Interval.step, verbose: bool = True,) -> None:
+        super().__init__(duration, interval, verbose)
+
+    def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> Dict[str, Any]:
+        return
+
+    def on_load_checkpoint(self, trainer, pl_module, callback_state) -> None:
+        return
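Why returning None from both hooks makes the timer per-run: the stock Timer contributes its elapsed time to the checkpoint dict and reads it back in on_load_checkpoint, so a resumed job only gets whatever is left of max_time; with both hooks neutralized, no timer state is ever saved or restored and the clock starts from zero on every run. A rough way to see the difference, assuming the PTL 1.x hook signatures shown above and that the stock Timer's on_save_checkpoint does not need a real trainer object:

# Hedged comparison sketch, not part of the diff.
from datetime import timedelta

from pytorch_lightning.callbacks.timer import Timer

from nemo.utils.exp_manager import StatelessTimer

stock = Timer(duration=timedelta(hours=4))
stateless = StatelessTimer(duration=timedelta(hours=4))

# Stock Timer returns its elapsed-time state to be stored in the checkpoint,
# so max_time keeps counting across resumed runs.
print(stock.on_save_checkpoint(trainer=None, pl_module=None, checkpoint={}))

# StatelessTimer returns None, so nothing is checkpointed or restored and the
# next run starts with the full max_time budget again.
print(stateless.on_save_checkpoint(trainer=None, pl_module=None, checkpoint={}))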