Add a stateless timer to specify max_time per run instead of global max_time across runs (#3056)

* Add a stateless timer to specify max_time per run instead of global max_time across runs

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
Sandeep Subramanian 2021-10-26 10:05:46 -06:00 committed by GitHub
parent e05712fea2
commit 0d0de9b367
2 changed files with 21 additions and 1 deletion
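Background for the diffs below: PyTorch Lightning converts the trainer's max_time setting into a Timer callback, and the stock Timer saves its elapsed time into checkpoints, so a resumed job only gets whatever remains of the original budget. The StatelessTimer introduced here skips that bookkeeping so every run gets the full budget. A rough sketch of where the value comes from (the time string and config layout are illustrative assumptions; only the cfg.trainer.max_time key appears in the diff):

from omegaconf import OmegaConf

# Illustrative config fragment; only trainer.max_time is read by the pretraining script.
cfg = OmegaConf.create({"trainer": {"max_time": "00:03:50:00"}})  # "DD:HH:MM:SS" budget per run
print(cfg.trainer.max_time)  # this string is what gets handed to StatelessTimer below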

examples/nlp/language_modeling/megatron_gpt_pretraining.py

@@ -16,6 +16,7 @@ from pathlib import Path
from omegaconf.omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.timer import Timer
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank
@@ -28,7 +29,7 @@ from nemo.collections.nlp.parts.nlp_overrides import (
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
from nemo.utils.exp_manager import StatelessTimer, exp_manager
@hydra_runner(config_path="conf", config_name="megatron_gpt_config")
@@ -69,6 +70,10 @@ def main(cfg) -> None:
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')
    trainer.checkpoint_connector = NLPCheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)
    # Override timer callback to a stateless one
    for idx, callback in enumerate(trainer.callbacks):
        if isinstance(callback, Timer):
            trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,)
    model = MegatronGPTModel(cfg.model, trainer)
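For context, the loop above works because Lightning appends a stock Timer to trainer.callbacks whenever max_time is given to the Trainer; the script then swaps that callback in place. A minimal standalone sketch of the same swap (the four-hour budget and the Trainer arguments are made-up values for illustration):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.timer import Timer
from nemo.utils.exp_manager import StatelessTimer

# Lightning builds a Timer callback from max_time (a "DD:HH:MM:SS" string or a timedelta).
trainer = Trainer(max_steps=100000, max_time="00:04:00:00")

# Replace it with the stateless variant so a resumed run counts down from the full four hours again.
for idx, callback in enumerate(trainer.callbacks):
    if isinstance(callback, Timer):
        trainer.callbacks[idx] = StatelessTimer("00:04:00:00")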

nemo/utils/exp_manager.py

@@ -19,6 +19,7 @@ import sys
import time
from copy import deepcopy
from dataclasses import dataclass
from datetime import timedelta
from pathlib import Path
from shutil import copy, move
from typing import Any, Dict, List, Optional, Union
@@ -28,6 +29,7 @@ from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.callbacks.timer import Interval, Timer
from pytorch_lightning.loggers import LoggerCollection as _LoggerCollection
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
@@ -919,3 +921,16 @@ def check_slurm(trainer):
        return trainer.accelerator_connector.is_slurm_managing_tasks
    except AttributeError:
        return False
class StatelessTimer(Timer):
    """Extension of PTL timers to be per run."""

    def __init__(self, duration: timedelta = None, interval: str = Interval.step, verbose: bool = True,) -> None:
        super().__init__(duration, interval, verbose)

    def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> Dict[str, Any]:
        return

    def on_load_checkpoint(self, trainer, pl_module, callback_state) -> None:
        return
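The two overrides above are what make the timer stateless: on_save_checkpoint returns nothing instead of the elapsed-time state the stock Timer would persist, and on_load_checkpoint ignores any state found in a checkpoint, so each run (including one resumed from a checkpoint) counts down from the full duration. A hedged sketch of constructing it directly (the 3h50m value and the idea of leaving headroom before a cluster time limit are illustrative, not from the commit):

from datetime import timedelta
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.timer import Interval
from nemo.utils.exp_manager import StatelessTimer

# E.g. inside a 4-hour SLURM allocation: stop ~10 minutes early so the final checkpoint can be written.
timer = StatelessTimer(duration=timedelta(hours=3, minutes=50), interval=Interval.step, verbose=True)
trainer = Trainer(max_steps=100000, callbacks=[timer])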