# Source: DeepLearningExamples/PyTorch/Segmentation/nnUNet/utils/scheduler.py
# Listing metadata from the original page: 18 lines, 719 B, Python.

import math
from torch.optim.lr_scheduler import LambdaLR
class WarmupCosineSchedule(LambdaLR):
    """Learning-rate schedule with linear warmup followed by cosine decay.

    The scheduler multiplies each parameter group's base learning rate by
    the factor returned from :meth:`lr_lambda`:

    * For ``step < warmup_steps`` the factor grows linearly from 0 towards 1.
    * Afterwards the factor follows
      ``0.5 * (1 + cos(pi * 2 * cycles * progress))``, where ``progress`` is
      the fraction of the post-warmup span
      (``t_total - warmup_steps`` steps) that has elapsed. With the default
      ``cycles=0.5`` this is a single half-cosine going from 1 at the end of
      warmup down to 0 at ``t_total``.

    Note:
        ``progress`` is not clamped, so if the scheduler is stepped past
        ``t_total`` the factor keeps following the cosine curve (and with
        the default ``cycles`` it starts rising again).

    Args:
        optimizer: Wrapped optimizer, forwarded to ``LambdaLR``.
        warmup_steps: Number of steps in the linear warmup phase.
        t_total: Step at which the post-warmup curve reaches
            ``progress == 1``.
        cycles: Number of cosine cycles spanned between the end of warmup
            and ``t_total``.
        last_epoch: Index of the last step, forwarded to ``LambdaLR``.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps: int,
        t_total: int,
        cycles: float = 0.5,
        last_epoch: int = -1,
    ) -> None:
        # These attributes must exist before the base initializer runs:
        # LambdaLR performs an initial step during construction, which
        # already calls ``self.lr_lambda``.
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step: int) -> float:
        """Return the multiplicative learning-rate factor for ``step``.

        Args:
            step: Current scheduler step.

        Returns:
            A non-negative factor applied to the base learning rate.
        """
        if step < self.warmup_steps:
            # Linear ramp; ``max`` guards the division when warmup_steps < 1.
            return float(step) / float(max(1.0, self.warmup_steps))

        # Fraction of the post-warmup span completed so far. ``max`` avoids
        # a zero or negative denominator when t_total <= warmup_steps.
        progress = float(step - self.warmup_steps) / float(
            max(1, self.t_total - self.warmup_steps)
        )
        # Floor at 0.0 so rounding error can never yield a negative factor.
        return max(
            0.0,
            0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress)),
        )