Timer class tracks total time (train + validation + testing) to decide when to end training (#3061)
* Check total time in train/validation to exit
  Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
* Style fixes
  Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
This commit is contained in:
parent
a1c15e72fc
commit
eb9117dcd7
|
@ -33,6 +33,8 @@ from pytorch_lightning.callbacks.timer import Interval, Timer
|
|||
from pytorch_lightning.loggers import LoggerCollection as _LoggerCollection
|
||||
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
|
||||
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
|
||||
from pytorch_lightning.trainer.states import RunningStage
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_info
|
||||
from pytorch_lightning.utilities.types import _METRIC
|
||||
|
||||
from nemo.constants import NEMO_ENV_VARNAME_VERSION
|
||||
|
@ -753,6 +755,9 @@ class NeMoModelCheckpoint(ModelCheckpoint):
|
|||
if trainer.fast_dev_run:
|
||||
return None
|
||||
|
||||
monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step - 1)
|
||||
self._save_last_checkpoint(trainer, monitor_candidates=monitor_candidates)
|
||||
|
||||
# Load the best model and then re-save it
|
||||
if self.save_best_model:
|
||||
trainer.checkpoint_connector.restore(self.best_model_path)
|
||||
|
@ -934,3 +939,19 @@ class StatelessTimer(Timer):
|
|||
|
||||
def on_load_checkpoint(self, trainer, pl_module, callback_state) -> None:
    """Do nothing when a checkpoint is loaded.

    All arguments, including ``callback_state``, are ignored, so nothing
    stored in a checkpoint is applied to this callback.
    """
    return None
|
||||
|
||||
def _check_time_remaining(self, trainer) -> None:
    """Signal the trainer to stop once the combined stage time reaches the limit.

    Sums the elapsed time of the training, validation and testing stages and
    compares the total against ``self._duration``. The local decision is passed
    through ``trainer.training_type_plugin.reduce_boolean_decision`` and then
    OR-ed into ``trainer.should_stop``, so an already-set stop flag is never
    cleared here.

    Args:
        trainer: Trainer whose ``should_stop`` flag is updated and whose
            ``training_type_plugin`` reduces the decision.
    """
    # The default timer only checks train time against max_time; this version
    # counts the time spent in every stage.
    train_duration = self.time_elapsed(RunningStage.TRAINING)
    validation_duration = self.time_elapsed(RunningStage.VALIDATING)
    test_duration = self.time_elapsed(RunningStage.TESTING)
    total_duration = train_duration + validation_duration + test_duration

    should_stop = total_duration >= self._duration
    should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
    trainer.should_stop = trainer.should_stop or should_stop

    if should_stop and self._verbose:
        rank_zero_info("Time limit reached. Signaling Trainer to stop.")
        # timedelta renders as H:MM:SS, so no "seconds" unit is appended.
        rank_zero_info(
            f"Spent {timedelta(seconds=train_duration)} on training, "
            f"{timedelta(seconds=validation_duration)} on validation and "
            f"{timedelta(seconds=test_duration)} on testing"
        )
|
||||
|
|
Loading…
Reference in a new issue