161 lines
4.3 KiB
Python
161 lines
4.3 KiB
Python
import math
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch import optim
|
|
|
|
|
|
def get_optimizer(parameters, lr, args, state=None):
    """Build an optimizer for *parameters* according to ``args.optimizer``.

    Args:
        parameters: iterable of ``(name, tensor)`` pairs (e.g. the result of
            ``model.named_parameters()``).
        lr: initial learning rate.
        args: namespace carrying the hyper-parameters read below
            (``optimizer``, ``momentum``, ``weight_decay``, ``nesterov``,
            ``bn_weight_decay``; for RMSprop also ``rmsprop_alpha`` and
            ``rmsprop_eps``).
        state: optional optimizer ``state_dict`` to restore, e.g. when
            resuming from a checkpoint.

    Returns:
        A configured ``torch.optim.Optimizer``.

    Raises:
        ValueError: if ``args.optimizer`` is neither ``"sgd"`` nor
            ``"rmsprop"``.
    """
    if args.optimizer == "sgd":
        optimizer = get_sgd_optimizer(
            parameters,
            lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
            bn_weight_decay=args.bn_weight_decay,
        )
    elif args.optimizer == "rmsprop":
        optimizer = get_rmsprop_optimizer(
            parameters,
            lr,
            alpha=args.rmsprop_alpha,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            eps=args.rmsprop_eps,
            bn_weight_decay=args.bn_weight_decay,
        )
    else:
        # Previously an unrecognized value fell through and hit an
        # UnboundLocalError on `return optimizer`; fail with a clear message.
        raise ValueError(f"Unknown optimizer {args.optimizer!r}")

    if state is not None:
        optimizer.load_state_dict(state)

    return optimizer
|
|
|
|
|
|
def get_sgd_optimizer(
    parameters, lr, momentum, weight_decay, nesterov=False, bn_weight_decay=False
):
    """Create an SGD optimizer, optionally exempting BN parameters from
    weight decay.

    Args:
        parameters: iterable of ``(name, tensor)`` pairs.
        lr: learning rate.
        momentum: SGD momentum.
        weight_decay: L2 penalty; applied to all parameters when
            ``bn_weight_decay`` is True, otherwise only to parameters whose
            name does not contain ``"bn"``.
        nesterov: enable Nesterov momentum.
        bn_weight_decay: apply weight decay to BN parameters too.

    Returns:
        A ``torch.optim.SGD`` instance.
    """
    # Materialize once: `parameters` may be a one-shot generator
    # (e.g. model.named_parameters()), and the no-BN-decay branch below
    # iterates it twice — the second pass would silently see nothing.
    parameters = list(parameters)

    if bn_weight_decay:
        print(" ! Weight decay applied to BN parameters ")
        params = [v for n, v in parameters]
    else:
        print(" ! Weight decay NOT applied to BN parameters ")
        bn_params = [v for n, v in parameters if "bn" in n]
        rest_params = [v for n, v in parameters if "bn" not in n]
        print(len(bn_params))
        print(len(rest_params))

        # Per-group settings override the constructor-level default, so the
        # BN group really gets weight_decay=0 despite the value passed below.
        params = [
            {"params": bn_params, "weight_decay": 0},
            {"params": rest_params, "weight_decay": weight_decay},
        ]

    optimizer = torch.optim.SGD(
        params, lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov
    )

    return optimizer
|
|
|
|
|
|
def get_rmsprop_optimizer(
    parameters, lr, alpha, weight_decay, momentum, eps, bn_weight_decay=False
):
    """Create an RMSprop optimizer, optionally exempting BN parameters from
    weight decay.

    Args:
        parameters: iterable of ``(name, tensor)`` pairs.
        lr: learning rate.
        alpha: RMSprop smoothing constant.
        weight_decay: L2 penalty for non-BN parameters (and BN parameters
            too when ``bn_weight_decay`` is True).
        momentum: RMSprop momentum.
        eps: denominator term for numerical stability.
        bn_weight_decay: apply weight decay to parameters whose name
            contains ``"bn"``.

    Returns:
        A ``torch.optim.RMSprop`` instance.
    """
    # Materialize once: `parameters` may be a one-shot generator
    # (e.g. model.named_parameters()) and is iterated twice below.
    parameters = list(parameters)

    bn_params = [v for n, v in parameters if "bn" in n]
    rest_params = [v for n, v in parameters if "bn" not in n]

    # Per-group weight_decay overrides the constructor-level default.
    params = [
        {"params": bn_params, "weight_decay": weight_decay if bn_weight_decay else 0},
        {"params": rest_params, "weight_decay": weight_decay},
    ]

    optimizer = torch.optim.RMSprop(
        params,
        lr=lr,
        alpha=alpha,
        weight_decay=weight_decay,
        momentum=momentum,
        eps=eps,
    )

    return optimizer
|
|
|
|
|
|
def lr_policy(lr_fn, logger=None):
    """Wrap *lr_fn* into an adjuster that applies the computed LR to an
    optimizer.

    Args:
        lr_fn: callable ``(iteration, epoch) -> lr``.
        logger: accepted for interface compatibility — ``lr_exponential_policy``
            in this module passes ``logger=`` and previously hit a TypeError
            because this parameter did not exist. Currently unused.

    Returns:
        Callable ``(optimizer, iteration, epoch)`` that writes the computed
        LR into every param group and returns it.
    """

    def _alr(optimizer, iteration, epoch):
        lr = lr_fn(iteration, epoch)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        return lr

    return _alr
|
|
|
|
|
|
def lr_step_policy(base_lr, steps, decay_factor, warmup_length):
    """Stepwise schedule: linear warmup over ``warmup_length`` epochs, then
    ``base_lr`` multiplied by ``decay_factor`` once for every milestone in
    ``steps`` that the current epoch has reached."""

    def _lr_fn(iteration, epoch):
        # Warmup: ramp linearly from base_lr/warmup_length up to base_lr.
        if epoch < warmup_length:
            return base_lr * (epoch + 1) / warmup_length

        lr = base_lr
        for milestone in steps:
            if epoch >= milestone:
                lr *= decay_factor
        return lr

    return lr_policy(_lr_fn)
|
|
|
|
|
|
def lr_linear_policy(base_lr, warmup_length, epochs):
    """Linear warmup over ``warmup_length`` epochs, then linear decay from
    ``base_lr`` down to zero at epoch ``epochs``."""

    def _lr_fn(iteration, epoch):
        # Warmup: ramp linearly from base_lr/warmup_length up to base_lr.
        if epoch < warmup_length:
            return base_lr * (epoch + 1) / warmup_length

        elapsed = epoch - warmup_length
        span = epochs - warmup_length
        return base_lr * (1 - (elapsed / span))

    return lr_policy(_lr_fn)
|
|
|
|
|
|
def lr_cosine_policy(base_lr, warmup_length, epochs, end_lr=0):
    """Linear warmup over ``warmup_length`` epochs, then cosine annealing
    from ``base_lr`` down to ``end_lr`` at epoch ``epochs``."""

    def _lr_fn(iteration, epoch):
        # Warmup: ramp linearly from base_lr/warmup_length up to base_lr.
        if epoch < warmup_length:
            return base_lr * (epoch + 1) / warmup_length

        elapsed = epoch - warmup_length
        span = epochs - warmup_length
        # Half-cosine window: 1 at the start of annealing, 0 at the end.
        window = 0.5 * (1 + np.cos(np.pi * elapsed / span))
        return end_lr + (window * (base_lr - end_lr))

    return lr_policy(_lr_fn)
|
|
|
|
|
|
def lr_exponential_policy(
    base_lr,
    warmup_length,
    epochs,
    final_multiplier=0.001,
    decay_factor=None,
    decay_step=1,
    logger=None,
):
    """Exponential LR schedule with linear warmup.

    Setting the ``decay_factor`` parameter overrides ``final_multiplier``.

    Args:
        base_lr: post-warmup starting learning rate.
        warmup_length: number of linear-warmup epochs.
        epochs: total number of epochs.
        final_multiplier: target ``final_lr / base_lr`` ratio used to derive
            the per-step decay when ``decay_factor`` is not given.
        decay_factor: explicit multiplicative decay applied every
            ``decay_step`` epochs; overrides ``final_multiplier``.
        decay_step: number of epochs between decays.
        logger: unused; retained for backward compatibility.

    Returns:
        An adjuster callable produced by ``lr_policy``.
    """
    es = epochs - warmup_length

    if decay_factor is not None:
        epoch_decay = decay_factor
    else:
        # Solve decay^(num_steps) == final_multiplier for the per-step decay,
        # so the LR reaches base_lr * final_multiplier by the last step.
        epoch_decay = np.power(
            2, np.log2(final_multiplier) / math.floor(es / decay_step)
        )

    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            lr = base_lr * (epoch + 1) / warmup_length
        else:
            e = epoch - warmup_length
            lr = base_lr * (epoch_decay ** math.floor(e / decay_step))
        return lr

    # BUG FIX: lr_policy() does not accept a `logger` keyword, so forwarding
    # it raised TypeError on every call. The parameter is kept but not passed.
    return lr_policy(_lr_fn)
|