Timer class monitors total time (train + validation + testing) to monitor when to end training (#3061)

* Check total time in train/validation to exit

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
This commit is contained in:
Sandeep Subramanian 2021-10-26 18:55:06 -06:00 committed by GitHub
parent a1c15e72fc
commit eb9117dcd7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -33,6 +33,8 @@ from pytorch_lightning.callbacks.timer import Interval, Timer
from pytorch_lightning.loggers import LoggerCollection as _LoggerCollection
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.distributed import rank_zero_info
from pytorch_lightning.utilities.types import _METRIC
from nemo.constants import NEMO_ENV_VARNAME_VERSION
@ -753,6 +755,9 @@ class NeMoModelCheckpoint(ModelCheckpoint):
if trainer.fast_dev_run:
return None
monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step - 1)
self._save_last_checkpoint(trainer, monitor_candidates=monitor_candidates)
# Load the best model and then re-save it
if self.save_best_model:
trainer.checkpoint_connector.restore(self.best_model_path)
@ -934,3 +939,19 @@ class StatelessTimer(Timer):
def on_load_checkpoint(self, trainer, pl_module, callback_state) -> None:
return
def _check_time_remaining(self, trainer) -> None:
# Default timer only checks for train time exceeding max_time, this includes time for all stages.
train_duration = self.time_elapsed(RunningStage.TRAINING)
validation_duration = self.time_elapsed(RunningStage.VALIDATING)
test_duration = self.time_elapsed(RunningStage.TESTING)
total_duration = train_duration + validation_duration + test_duration
should_stop = total_duration >= self._duration
# should_stop = trainer.training_type_plugin.broadcast(should_stop)
should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
trainer.should_stop = trainer.should_stop or should_stop
if should_stop and self._verbose:
rank_zero_info(f"Time limit reached. Signaling Trainer to stop.")
rank_zero_info(
f"Spent {timedelta(seconds=train_duration)} seconds on training, {timedelta(seconds=validation_duration)} seconds on validation and {timedelta(seconds=test_duration)} seconds on testing"
)