580 lines
17 KiB
Python
580 lines
17 KiB
Python
# Copyright (c) 2018-2019, NVIDIA CORPORATION
|
|
# Copyright (c) 2017- Facebook, Inc
|
|
#
|
|
# All rights reserved.
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
#
|
|
# * Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
#
|
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
#
|
|
# * Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
import os
|
|
import time
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.autograd import Variable
|
|
from . import logger as log
|
|
from . import resnet as models
|
|
from . import utils
|
|
import dllogger
|
|
|
|
try:
|
|
from apex.parallel import DistributedDataParallel as DDP
|
|
from apex.fp16_utils import *
|
|
from apex import amp
|
|
except ImportError:
|
|
raise ImportError(
|
|
"Please install apex from https://www.github.com/nvidia/apex to run this example."
|
|
)
|
|
|
|
ACC_METADATA = {"unit": "%", "format": ":.2f"}
|
|
IPS_METADATA = {"unit": "img/s", "format": ":.2f"}
|
|
TIME_METADATA = {"unit": "s", "format": ":.5f"}
|
|
LOSS_METADATA = {"format": ":.5f"}
|
|
|
|
|
|
class ModelAndLoss(nn.Module):
    """Bundle a ResNet model with its loss criterion.

    A single forward call returns both the loss and the raw network
    output, which keeps the training/validation step code uniform.

    Args:
        arch: triple forwarded to ``models.build_resnet(arch[0], arch[1],
            arch[2])``.
        loss: zero-argument factory returning the loss criterion
            (e.g. ``nn.CrossEntropyLoss``).
        pretrained_weights: optional state dict loaded into the model.
        cuda: move model and criterion to GPU when True.
        fp16: convert the network to half precision via apex
            ``network_to_half``.
        memory_format: torch memory layout used when moving to CUDA.
    """

    def __init__(
        self,
        arch,
        loss,
        pretrained_weights=None,
        cuda=True,
        fp16=False,
        memory_format=torch.contiguous_format,
    ):
        super(ModelAndLoss, self).__init__()
        self.arch = arch

        print("=> creating model '{}'".format(arch))
        model = models.build_resnet(arch[0], arch[1], arch[2])
        if pretrained_weights is not None:
            print("=> using pre-trained model from a file '{}'".format(arch))
            model.load_state_dict(pretrained_weights)

        if cuda:
            model = model.cuda().to(memory_format=memory_format)
        if fp16:
            model = network_to_half(model)

        # define loss function (criterion) and optimizer
        criterion = loss()

        if cuda:
            criterion = criterion.cuda()

        self.model = model
        self.loss = criterion

    def forward(self, data, target):
        """Run the model on ``data`` and return ``(loss, output)``."""
        output = self.model(data)
        loss = self.loss(output, target)

        return loss, output

    def distributed(self):
        """Wrap the model with apex DistributedDataParallel."""
        self.model = DDP(self.model)

    def load_model_state(self, state):
        """Load a model state dict; a None state is silently ignored."""
        # Fixed `not state is None` to the PEP 8 `is not None` form.
        if state is not None:
            self.model.load_state_dict(state)
|
|
|
|
|
|
def get_optimizer(
    parameters,
    fp16,
    lr,
    momentum,
    weight_decay,
    nesterov=False,
    state=None,
    static_loss_scale=1.0,
    dynamic_loss_scale=False,
    bn_weight_decay=False,
):
    """Build an SGD optimizer, optionally wrapped by apex FP16_Optimizer.

    Args:
        parameters: iterable of ``(name, parameter)`` pairs, e.g. from
            ``model.named_parameters()``.
        fp16: wrap the optimizer in apex ``FP16_Optimizer`` when True.
        lr, momentum, weight_decay, nesterov: standard SGD hyperparameters.
        state: optional optimizer state dict to restore.
        static_loss_scale, dynamic_loss_scale: apex loss-scaling options.
        bn_weight_decay: also apply weight decay to batch-norm parameters
            (identified by a "bn" substring in their name).

    Returns:
        A ``torch.optim.SGD`` instance, possibly FP16_Optimizer-wrapped.
    """
    # Materialize once: the no-BN-decay branch iterates `parameters` twice,
    # which would silently yield empty parameter groups if callers passed a
    # generator such as model.named_parameters().
    parameters = list(parameters)

    if bn_weight_decay:
        print(" ! Weight decay applied to BN parameters ")
        optimizer = torch.optim.SGD(
            [v for n, v in parameters],
            lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
    else:
        print(" ! Weight decay NOT applied to BN parameters ")
        # Batch-norm parameters are recognized purely by their name.
        bn_params = [v for n, v in parameters if "bn" in n]
        rest_params = [v for n, v in parameters if "bn" not in n]
        print(len(bn_params))
        print(len(rest_params))
        optimizer = torch.optim.SGD(
            [
                {"params": bn_params, "weight_decay": 0},
                {"params": rest_params, "weight_decay": weight_decay},
            ],
            lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
    if fp16:
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=static_loss_scale,
            dynamic_loss_scale=dynamic_loss_scale,
            verbose=False,
        )

    # Fixed `not state is None` to the PEP 8 `is not None` form.
    if state is not None:
        optimizer.load_state_dict(state)

    return optimizer
|
|
|
|
|
|
def lr_policy(lr_fn, logger=None):
    """Turn an ``(iteration, epoch) -> lr`` schedule into an applier.

    The returned callable evaluates the schedule and writes the result
    into every parameter group of the given optimizer. When a logger is
    supplied, an "lr" metric is registered up front and logged on each
    application.
    """
    if logger is not None:
        logger.register_metric(
            "lr", log.LR_METER(), verbosity=dllogger.Verbosity.VERBOSE
        )

    def _alr(optimizer, iteration, epoch):
        new_lr = lr_fn(iteration, epoch)

        if logger is not None:
            logger.log_metric("lr", new_lr)
        for group in optimizer.param_groups:
            group["lr"] = new_lr

    return _alr
|
|
|
|
|
|
def lr_step_policy(base_lr, steps, decay_factor, warmup_length, logger=None):
    """Step schedule: linear warmup, then one ``decay_factor`` multiply
    for every epoch milestone in ``steps`` that has been passed."""

    def _lr_fn(iteration, epoch):
        # Warmup ramps linearly from base_lr/warmup_length up to base_lr.
        if epoch < warmup_length:
            return base_lr * (epoch + 1) / warmup_length
        rate = base_lr
        for milestone in steps:
            if epoch >= milestone:
                rate *= decay_factor
        return rate

    return lr_policy(_lr_fn, logger=logger)
|
|
|
|
|
|
def lr_linear_policy(base_lr, warmup_length, epochs, logger=None):
    """Linear warmup followed by a linear decay from ``base_lr`` to 0
    at the final epoch."""

    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            # Warmup ramps linearly up to base_lr.
            return base_lr * (epoch + 1) / warmup_length
        progressed = epoch - warmup_length
        total = epochs - warmup_length
        return base_lr * (1 - (progressed / total))

    return lr_policy(_lr_fn, logger=logger)
|
|
|
|
|
|
def lr_cosine_policy(base_lr, warmup_length, epochs, logger=None):
    """Linear warmup followed by half-cosine annealing down to 0."""

    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            # Warmup ramps linearly up to base_lr.
            return base_lr * (epoch + 1) / warmup_length
        progressed = epoch - warmup_length
        total = epochs - warmup_length
        return 0.5 * (1 + np.cos(np.pi * progressed / total)) * base_lr

    return lr_policy(_lr_fn, logger=logger)
|
|
|
|
|
|
def lr_exponential_policy(
    base_lr, warmup_length, epochs, final_multiplier=0.001, logger=None
):
    """Linear warmup, then per-epoch exponential decay chosen so the rate
    reaches ``final_multiplier * base_lr`` after the post-warmup epochs."""
    es = epochs - warmup_length
    # Per-epoch multiplier such that epoch_decay ** es == final_multiplier.
    epoch_decay = np.power(2, np.log2(final_multiplier) / es)

    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            # Warmup ramps linearly up to base_lr.
            return base_lr * (epoch + 1) / warmup_length
        return base_lr * (epoch_decay ** (epoch - warmup_length))

    return lr_policy(_lr_fn, logger=logger)
|
|
|
|
|
|
def get_train_step(
    model_and_loss, optimizer, fp16, use_amp=False, batch_size_multiplier=1
):
    """Build the single-batch training step closure.

    The returned ``_step(input, target, optimizer_step=True)`` runs forward
    and backward for one batch; when ``optimizer_step`` is True it also
    rescales the accumulated gradients by ``batch_size_multiplier``, steps
    the optimizer, and zeroes the gradients. It returns the loss value,
    averaged across workers when torch.distributed is initialized.
    """

    def _step(input, target, optimizer_step=True):
        input_var = Variable(input)
        target_var = Variable(target)
        loss, output = model_and_loss(input_var, target_var)
        # Report a loss averaged over all workers in distributed runs.
        if torch.distributed.is_initialized():
            reduced_loss = utils.reduce_tensor(loss.data)
        else:
            reduced_loss = loss.data

        # Three backward paths: apex FP16 optimizer, apex amp loss
        # scaling, or plain fp32 autograd.
        if fp16:
            optimizer.backward(loss)
        elif use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if optimizer_step:
            # FP16_Optimizer wraps the real optimizer; unwrap it to reach
            # the parameter groups that hold the gradients.
            opt = (
                optimizer.optimizer
                if isinstance(optimizer, FP16_Optimizer)
                else optimizer
            )
            # Gradients were accumulated over batch_size_multiplier
            # batches; rescale so the step applies their average.
            for param_group in opt.param_groups:
                for param in param_group["params"]:
                    param.grad /= batch_size_multiplier

            optimizer.step()
            optimizer.zero_grad()

        # Ensure queued GPU work finishes so callers' timings are accurate.
        torch.cuda.synchronize()

        return reduced_loss

    return _step
|
|
|
|
|
|
def train(
    train_loader,
    model_and_loss,
    optimizer,
    lr_scheduler,
    fp16,
    logger,
    epoch,
    use_amp=False,
    prof=-1,
    batch_size_multiplier=1,
    register_metrics=True,
):
    """Run one training epoch over ``train_loader``.

    Args:
        train_loader: iterable of (input, target) batches.
        model_and_loss: callable returning (loss, output) for a batch.
        optimizer: optimizer, possibly apex FP16_Optimizer-wrapped.
        lr_scheduler: callable ``(optimizer, iteration, epoch)`` applying
            the learning rate for the current step.
        fp16: use the apex fp16 optimizer backward path.
        logger: experiment logger or None; all metric logging is skipped
            when None.
        epoch: current epoch index, forwarded to the lr scheduler.
        use_amp: use apex amp loss scaling instead of the fp16 path.
        prof: when > 0, process only that many iterations (profiling).
        batch_size_multiplier: gradient-accumulation factor; the optimizer
            steps only every ``batch_size_multiplier``-th iteration.
        register_metrics: register the train.* metrics on the logger
            before the loop (done only on the first epoch by callers).
    """

    if register_metrics and logger is not None:
        # Loss plus throughput/timing meters reported for this epoch.
        logger.register_metric(
            "train.loss",
            log.LOSS_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=LOSS_METADATA,
        )
        logger.register_metric(
            "train.compute_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            "train.total_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            "train.data_time",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "train.compute_time",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )

    step = get_train_step(
        model_and_loss,
        optimizer,
        fp16,
        use_amp=use_amp,
        batch_size_multiplier=batch_size_multiplier,
    )

    model_and_loss.train()
    end = time.time()

    optimizer.zero_grad()

    data_iter = enumerate(train_loader)
    if logger is not None:
        data_iter = logger.iteration_generator_wrapper(data_iter)
    if prof > 0:
        # Truncate the epoch to `prof` iterations for profiling runs.
        data_iter = utils.first_n(prof, data_iter)

    for i, (input, target) in data_iter:
        bs = input.size(0)
        lr_scheduler(optimizer, i, epoch)
        # Time spent waiting on the data loader for this iteration.
        data_time = time.time() - end

        # Step the optimizer only every batch_size_multiplier-th batch
        # (gradient accumulation).
        optimizer_step = ((i + 1) % batch_size_multiplier) == 0
        loss = step(input, target, optimizer_step=optimizer_step)

        # Total iteration time: data loading + compute.
        it_time = time.time() - end

        if logger is not None:
            logger.log_metric("train.loss", to_python_float(loss), bs)
            logger.log_metric("train.compute_ips", calc_ips(bs, it_time - data_time))
            logger.log_metric("train.total_ips", calc_ips(bs, it_time))
            logger.log_metric("train.data_time", data_time)
            logger.log_metric("train.compute_time", it_time - data_time)

        end = time.time()
|
|
|
|
|
|
def get_val_step(model_and_loss):
    """Build the single-batch validation step.

    The returned callable runs the model under ``torch.no_grad()``,
    computes top-1/top-5 precision, and — when torch.distributed is
    initialized — averages loss and precisions across workers.
    """

    def _step(input, target):
        batch_in = Variable(input)
        batch_tgt = Variable(target)

        # Inference only: no autograd graph is needed for validation.
        with torch.no_grad():
            loss, output = model_and_loss(batch_in, batch_tgt)

        top1, top5 = utils.accuracy(output.data, target, topk=(1, 5))

        if torch.distributed.is_initialized():
            mean_loss = utils.reduce_tensor(loss.data)
            top1 = utils.reduce_tensor(top1)
            top5 = utils.reduce_tensor(top5)
        else:
            mean_loss = loss.data

        # Ensure queued GPU work finishes so callers' timings are accurate.
        torch.cuda.synchronize()

        return mean_loss, top1, top5

    return _step
|
|
|
|
|
|
def validate(
    val_loader, model_and_loss, fp16, logger, epoch, prof=-1, register_metrics=True
):
    """Run one validation pass over ``val_loader``.

    Args:
        val_loader: iterable of (input, target) batches.
        model_and_loss: callable returning (loss, output) for a batch.
        fp16: unused in this function body; kept for signature symmetry
            with train().
        logger: experiment logger or None; metric logging is skipped
            when None.
        epoch: unused in this function body; kept for signature symmetry
            with train().
        prof: when > 0, process only that many iterations.
        register_metrics: register the val.* metrics on the logger before
            the loop (done only on the first epoch by callers).

    Returns:
        ``top1.get_val()`` from log.AverageMeter — presumably the epoch
        top-1 accuracy (callers unpack it as a pair; confirm against the
        log module).
    """
    if register_metrics and logger is not None:
        # Accuracy, loss, throughput, and latency-percentile meters.
        logger.register_metric(
            "val.top1",
            log.ACC_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=ACC_METADATA,
        )
        logger.register_metric(
            "val.top5",
            log.ACC_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=ACC_METADATA,
        )
        logger.register_metric(
            "val.loss",
            log.LOSS_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=LOSS_METADATA,
        )
        logger.register_metric(
            "val.compute_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            "val.total_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            "val.data_time",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "val.compute_latency",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "val.compute_latency_at100",
            log.LAT_100(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "val.compute_latency_at99",
            log.LAT_99(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "val.compute_latency_at95",
            log.LAT_95(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )

    step = get_val_step(model_and_loss)

    top1 = log.AverageMeter()
    # switch to evaluate mode
    model_and_loss.eval()

    end = time.time()

    data_iter = enumerate(val_loader)
    # NOTE(review): kept as-is; PEP 8 would spell this "is not None".
    if not logger is None:
        data_iter = logger.iteration_generator_wrapper(data_iter, val=True)
    if prof > 0:
        # Truncate the pass to `prof` iterations for profiling runs.
        data_iter = utils.first_n(prof, data_iter)

    for i, (input, target) in data_iter:
        bs = input.size(0)
        # Time spent waiting on the data loader for this iteration.
        data_time = time.time() - end

        loss, prec1, prec5 = step(input, target)

        # Total iteration time: data loading + compute.
        it_time = time.time() - end

        top1.record(to_python_float(prec1), bs)
        if logger is not None:
            logger.log_metric("val.top1", to_python_float(prec1), bs)
            logger.log_metric("val.top5", to_python_float(prec5), bs)
            logger.log_metric("val.loss", to_python_float(loss), bs)
            logger.log_metric("val.compute_ips", calc_ips(bs, it_time - data_time))
            logger.log_metric("val.total_ips", calc_ips(bs, it_time))
            logger.log_metric("val.data_time", data_time)
            # The same per-iteration latency feeds all percentile meters;
            # the LAT_* meters compute the percentiles internally.
            logger.log_metric("val.compute_latency", it_time - data_time)
            logger.log_metric("val.compute_latency_at95", it_time - data_time)
            logger.log_metric("val.compute_latency_at99", it_time - data_time)
            logger.log_metric("val.compute_latency_at100", it_time - data_time)

        end = time.time()

    return top1.get_val()
|
|
|
|
|
|
# Train loop {{{
|
|
def calc_ips(batch_size, time):
    """Images per second across all workers: world_size * batch_size / time.

    NOTE(review): the parameter name ``time`` shadows the stdlib ``time``
    module inside this function's scope (kept for caller compatibility).
    """
    if torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
    else:
        world_size = 1
    return world_size * batch_size / time
|
|
|
|
|
|
def train_loop(
    model_and_loss,
    optimizer,
    lr_scheduler,
    train_loader,
    val_loader,
    fp16,
    logger,
    should_backup_checkpoint,
    use_amp=False,
    batch_size_multiplier=1,
    best_prec1=0,
    start_epoch=0,
    end_epoch=0,
    prof=-1,
    skip_training=False,
    skip_validation=False,
    save_checkpoints=True,
    checkpoint_dir="./",
    checkpoint_filename="checkpoint.pth.tar",
):
    """Full training driver: per-epoch train, validate, and checkpoint.

    For each epoch in ``[start_epoch, end_epoch)``: optionally train,
    optionally validate, then (on rank 0, or in non-distributed runs)
    optionally save a checkpoint, tracking the best top-1 so far.

    Args:
        should_backup_checkpoint: callable ``epoch -> bool`` deciding
            whether to additionally write an epoch-numbered backup file.
        best_prec1: best top-1 seen so far (for resumed runs).
        skip_training / skip_validation: disable the respective phase.
        save_checkpoints: write checkpoints at the end of each epoch.
        Remaining arguments are forwarded to train()/validate().
    """

    prec1 = -1

    print(f"RUNNING EPOCHS FROM {start_epoch} TO {end_epoch}")
    for epoch in range(start_epoch, end_epoch):
        if logger is not None:
            logger.start_epoch()
        if not skip_training:
            train(
                train_loader,
                model_and_loss,
                optimizer,
                lr_scheduler,
                fp16,
                logger,
                epoch,
                use_amp=use_amp,
                prof=prof,
                # Metrics are registered once, on the first epoch only.
                register_metrics=epoch == start_epoch,
                batch_size_multiplier=batch_size_multiplier,
            )

        if not skip_validation:
            # NOTE(review): this unpack assumes validate() returns a pair
            # (presumably top-1 value and a count from
            # log.AverageMeter.get_val()) — confirm against the log module.
            prec1, nimg = validate(
                val_loader,
                model_and_loss,
                fp16,
                logger,
                epoch,
                prof=prof,
                register_metrics=epoch == start_epoch,
            )
        if logger is not None:
            logger.end_epoch()

        # Only rank 0 (or a non-distributed run) writes checkpoints.
        if save_checkpoints and (
            not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
        ):
            if not skip_validation:
                # NOTE(review): assumes logger is not None on this path —
                # a None logger with validation enabled would raise here.
                is_best = logger.metrics["val.top1"]["meter"].get_epoch() > best_prec1
                best_prec1 = max(
                    logger.metrics["val.top1"]["meter"].get_epoch(), best_prec1
                )
            else:
                # Without validation there is no accuracy to compare.
                is_best = False
                best_prec1 = 0

            if should_backup_checkpoint(epoch):
                backup_filename = "checkpoint-{}.pth.tar".format(epoch + 1)
            else:
                backup_filename = None
            utils.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": model_and_loss.arch,
                    "state_dict": model_and_loss.model.state_dict(),
                    "best_prec1": best_prec1,
                    "optimizer": optimizer.state_dict(),
                },
                is_best,
                checkpoint_dir=checkpoint_dir,
                backup_filename=backup_filename,
                filename=checkpoint_filename,
            )
|
|
|
|
|
|
# }}}
|