18 lines
719 B
Python
18 lines
719 B
Python
import math
|
|
|
|
from torch.optim.lr_scheduler import LambdaLR
|
|
|
|
|
|
class WarmupCosineSchedule(LambdaLR):
    """Linear warmup followed by cosine decay of the learning-rate multiplier.

    During the first ``warmup_steps`` steps the multiplier rises linearly
    from 0.0 towards 1.0. After that it follows
    ``0.5 * (1 + cos(2 * pi * cycles * progress))``, where ``progress`` goes
    from 0 at ``warmup_steps`` to 1 at ``t_total``. With the default
    ``cycles=0.5`` this decays from 1.0 to 0.0 over the remaining steps.

    The multiplier is clamped below at 0.0. Past ``t_total`` it keeps
    following the cosine curve; there is no clamp above.

    Args:
        optimizer: Optimizer whose learning rates are scaled.
        warmup_steps: Number of linear-warmup steps.
        t_total: Step at which ``progress`` reaches 1.
        cycles: Number of cosine periods over the decay phase.
        last_epoch: Index of the last step, forwarded to ``LambdaLR``.
    """

    def __init__(self, optimizer, warmup_steps, t_total, cycles=0.5, last_epoch=-1):
        # These must exist before LambdaLR.__init__ runs, because it
        # evaluates ``self.lr_lambda`` immediately.
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        """Return the LR multiplier for ``step``, the counter LambdaLR passes in."""
        warmup = self.warmup_steps
        if step >= warmup:
            # Guard the denominator so t_total <= warmup_steps never divides by 0.
            span = float(max(1, self.t_total - warmup))
            fraction = float(step - warmup) / span
            angle = math.pi * float(self.cycles) * 2.0 * fraction
            return max(0.0, 0.5 * (1.0 + math.cos(angle)))
        # Warmup phase: the 1.0 floor avoids division by zero when warmup_steps == 0.
        return float(step) / float(max(1.0, warmup))