DeepLearningExamples/PyTorch/Classification/ConvNets/image_classification/training.py

# Copyright (c) 2018-2019, NVIDIA CORPORATION
# Copyright (c) 2017- Facebook, Inc
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import time

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

from . import logger as log
from . import resnet as models
from . import utils

import dllogger

# apex (https://github.com/NVIDIA/apex) provides the DistributedDataParallel
# wrapper, the fp16 utilities used below (network_to_half, FP16_Optimizer,
# to_python_float), and the amp mixed-precision API.
try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
    from apex import amp
except ImportError:
    raise ImportError(
        "Please install apex from https://www.github.com/nvidia/apex to run this example."
    )

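# Formatting metadata passed to dllogger when metrics are registered below:
# units and print formats for accuracy, throughput, timing, and loss.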
ACC_METADATA = {"unit": "%", "format": ":.2f"}
IPS_METADATA = {"unit": "img/s", "format": ":.2f"}
TIME_METADATA = {"unit": "s", "format": ":.5f"}
LOSS_METADATA = {"format": ":.5f"}

class ModelAndLoss(nn.Module):
    def __init__(
        self,
        arch,
        loss,
        pretrained_weights=None,
        cuda=True,
        fp16=False,
        memory_format=torch.contiguous_format,
    ):
        super(ModelAndLoss, self).__init__()
        self.arch = arch

        print("=> creating model '{}'".format(arch))
        model = models.build_resnet(arch[0], arch[1], arch[2])
        if pretrained_weights is not None:
            print("=> loading pretrained weights into model '{}'".format(arch))
            model.load_state_dict(pretrained_weights)
        if cuda:
            model = model.cuda().to(memory_format=memory_format)
        if fp16:
            model = network_to_half(model)

        # define loss function (criterion)
        criterion = loss()
        if cuda:
            criterion = criterion.cuda()

        self.model = model
        self.loss = criterion

    def forward(self, data, target):
        output = self.model(data)
        loss = self.loss(output, target)
        return loss, output

    def distributed(self):
        self.model = DDP(self.model)

    def load_model_state(self, state):
        if state is not None:
            self.model.load_state_dict(state)
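
# Usage sketch (illustrative only; the arch tuple and loss shown here are
# hypothetical examples of the (name, config, num_classes) triple that
# models.build_resnet consumes):
#
#   model_and_loss = ModelAndLoss(("resnet50", "classic", 1000), nn.CrossEntropyLoss)
#   if torch.distributed.is_initialized():
#       model_and_loss.distributed()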


def get_optimizer(
    parameters,
    fp16,
    lr,
    momentum,
    weight_decay,
    nesterov=False,
    state=None,
    static_loss_scale=1.0,
    dynamic_loss_scale=False,
    bn_weight_decay=False,
):
    if bn_weight_decay:
        print(" ! Weight decay applied to BN parameters ")
        optimizer = torch.optim.SGD(
            [v for n, v in parameters],
            lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
    else:
        print(" ! Weight decay NOT applied to BN parameters ")
        # split parameters by name: batch-norm parameters get no weight decay
        bn_params = [v for n, v in parameters if "bn" in n]
        rest_params = [v for n, v in parameters if "bn" not in n]
        print("BN parameters:", len(bn_params))
        print("remaining parameters:", len(rest_params))
        optimizer = torch.optim.SGD(
            [
                {"params": bn_params, "weight_decay": 0},
                {"params": rest_params, "weight_decay": weight_decay},
            ],
            lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )

    if fp16:
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=static_loss_scale,
            dynamic_loss_scale=dynamic_loss_scale,
            verbose=False,
        )

    if state is not None:
        optimizer.load_state_dict(state)

    return optimizer
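
# Example call (a sketch; the hyperparameter values are illustrative, not the
# repository's defaults). Note that `parameters` must be re-iterable (e.g. a
# list, not a generator), since it is traversed twice when BN parameters are
# split out:
#
#   optimizer = get_optimizer(
#       list(model_and_loss.model.named_parameters()),
#       fp16=False, lr=0.1, momentum=0.9, weight_decay=1e-4,
#   )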


def lr_policy(lr_fn, logger=None):
    if logger is not None:
        logger.register_metric(
            "lr", log.LR_METER(), verbosity=dllogger.Verbosity.VERBOSE
        )

    def _alr(optimizer, iteration, epoch):
        lr = lr_fn(iteration, epoch)

        if logger is not None:
            logger.log_metric("lr", lr)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

    return _alr


def lr_step_policy(base_lr, steps, decay_factor, warmup_length, logger=None):
    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            lr = base_lr * (epoch + 1) / warmup_length
        else:
            lr = base_lr
            for s in steps:
                if epoch >= s:
                    lr *= decay_factor
        return lr

    return lr_policy(_lr_fn, logger=logger)


def lr_linear_policy(base_lr, warmup_length, epochs, logger=None):
    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            lr = base_lr * (epoch + 1) / warmup_length
        else:
            e = epoch - warmup_length
            es = epochs - warmup_length
            lr = base_lr * (1 - (e / es))
        return lr

    return lr_policy(_lr_fn, logger=logger)


def lr_cosine_policy(base_lr, warmup_length, epochs, logger=None):
    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            lr = base_lr * (epoch + 1) / warmup_length
        else:
            e = epoch - warmup_length
            es = epochs - warmup_length
            lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
        return lr

    return lr_policy(_lr_fn, logger=logger)
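
# After warmup, the cosine schedule above follows
#   lr(e) = 0.5 * (1 + cos(pi * e / es)) * base_lr
# with e = epoch - warmup_length and es = epochs - warmup_length, annealing
# from base_lr at e = 0 down to 0 at e = es.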


def lr_exponential_policy(
    base_lr, warmup_length, epochs, final_multiplier=0.001, logger=None
):
    es = epochs - warmup_length
    epoch_decay = np.power(2, np.log2(final_multiplier) / es)

    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            lr = base_lr * (epoch + 1) / warmup_length
        else:
            e = epoch - warmup_length
            lr = base_lr * (epoch_decay ** e)
        return lr

    return lr_policy(_lr_fn, logger=logger)
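
# Note: 2 ** (log2(final_multiplier) / es) == final_multiplier ** (1 / es),
# so after es post-warmup epochs the learning rate has decayed to exactly
# base_lr * final_multiplier.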


def get_train_step(
    model_and_loss, optimizer, fp16, use_amp=False, batch_size_multiplier=1
):
    def _step(input, target, optimizer_step=True):
        input_var = Variable(input)
        target_var = Variable(target)
        loss, output = model_and_loss(input_var, target_var)
        if torch.distributed.is_initialized():
            reduced_loss = utils.reduce_tensor(loss.data)
        else:
            reduced_loss = loss.data

        if fp16:
            optimizer.backward(loss)
        elif use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if optimizer_step:
            opt = (
                optimizer.optimizer
                if isinstance(optimizer, FP16_Optimizer)
                else optimizer
            )
            # gradients were accumulated over batch_size_multiplier
            # micro-batches; rescale to their average before the update
            for param_group in opt.param_groups:
                for param in param_group["params"]:
                    param.grad /= batch_size_multiplier

            optimizer.step()
            optimizer.zero_grad()

        torch.cuda.synchronize()

        return reduced_loss

    return _step
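
# Gradient accumulation: with batch_size_multiplier > 1, train() calls _step
# with optimizer_step=False for all but the last micro-batch of each group,
# so gradients sum across micro-batches (zero_grad only runs after a step).
# The division by batch_size_multiplier above averages them, making the
# update equivalent to one step over the combined batch.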


def train(
    train_loader,
    model_and_loss,
    optimizer,
    lr_scheduler,
    fp16,
    logger,
    epoch,
    use_amp=False,
    prof=-1,
    batch_size_multiplier=1,
    register_metrics=True,
):
    if register_metrics and logger is not None:
        logger.register_metric(
            "train.loss",
            log.LOSS_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=LOSS_METADATA,
        )
        logger.register_metric(
            "train.compute_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            "train.total_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            "train.data_time",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "train.compute_time",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )

    step = get_train_step(
        model_and_loss,
        optimizer,
        fp16,
        use_amp=use_amp,
        batch_size_multiplier=batch_size_multiplier,
    )

    model_and_loss.train()
    end = time.time()

    optimizer.zero_grad()

    data_iter = enumerate(train_loader)
    if logger is not None:
        data_iter = logger.iteration_generator_wrapper(data_iter)
    if prof > 0:
        data_iter = utils.first_n(prof, data_iter)

    for i, (input, target) in data_iter:
        bs = input.size(0)
        lr_scheduler(optimizer, i, epoch)
        data_time = time.time() - end

        optimizer_step = ((i + 1) % batch_size_multiplier) == 0
        loss = step(input, target, optimizer_step=optimizer_step)

        it_time = time.time() - end

        if logger is not None:
            logger.log_metric("train.loss", to_python_float(loss), bs)
            logger.log_metric("train.compute_ips", calc_ips(bs, it_time - data_time))
            logger.log_metric("train.total_ips", calc_ips(bs, it_time))
            logger.log_metric("train.data_time", data_time)
            logger.log_metric("train.compute_time", it_time - data_time)

        end = time.time()


def get_val_step(model_and_loss):
    def _step(input, target):
        input_var = Variable(input)
        target_var = Variable(target)

        with torch.no_grad():
            loss, output = model_and_loss(input_var, target_var)

            prec1, prec5 = utils.accuracy(output.data, target, topk=(1, 5))

            if torch.distributed.is_initialized():
                reduced_loss = utils.reduce_tensor(loss.data)
                prec1 = utils.reduce_tensor(prec1)
                prec5 = utils.reduce_tensor(prec5)
            else:
                reduced_loss = loss.data

        torch.cuda.synchronize()

        return reduced_loss, prec1, prec5

    return _step


def validate(
    val_loader, model_and_loss, fp16, logger, epoch, prof=-1, register_metrics=True
):
    if register_metrics and logger is not None:
        logger.register_metric(
            "val.top1",
            log.ACC_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=ACC_METADATA,
        )
        logger.register_metric(
            "val.top5",
            log.ACC_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=ACC_METADATA,
        )
        logger.register_metric(
            "val.loss",
            log.LOSS_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=LOSS_METADATA,
        )
        logger.register_metric(
            "val.compute_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            "val.total_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            "val.data_time",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "val.compute_latency",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "val.compute_latency_at100",
            log.LAT_100(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "val.compute_latency_at99",
            log.LAT_99(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "val.compute_latency_at95",
            log.LAT_95(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )

    step = get_val_step(model_and_loss)

    top1 = log.AverageMeter()
    # switch to evaluate mode
    model_and_loss.eval()

    end = time.time()

    data_iter = enumerate(val_loader)
    if logger is not None:
        data_iter = logger.iteration_generator_wrapper(data_iter, val=True)
    if prof > 0:
        data_iter = utils.first_n(prof, data_iter)

    for i, (input, target) in data_iter:
        bs = input.size(0)
        data_time = time.time() - end

        loss, prec1, prec5 = step(input, target)

        it_time = time.time() - end

        top1.record(to_python_float(prec1), bs)
        if logger is not None:
            logger.log_metric("val.top1", to_python_float(prec1), bs)
            logger.log_metric("val.top5", to_python_float(prec5), bs)
            logger.log_metric("val.loss", to_python_float(loss), bs)
            logger.log_metric("val.compute_ips", calc_ips(bs, it_time - data_time))
            logger.log_metric("val.total_ips", calc_ips(bs, it_time))
            logger.log_metric("val.data_time", data_time)
            logger.log_metric("val.compute_latency", it_time - data_time)
            logger.log_metric("val.compute_latency_at95", it_time - data_time)
            logger.log_metric("val.compute_latency_at99", it_time - data_time)
            logger.log_metric("val.compute_latency_at100", it_time - data_time)

        end = time.time()

    return top1.get_val()


# Train loop {{{
def calc_ips(batch_size, time):
    # throughput in images/s, aggregated over all distributed workers
    world_size = (
        torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
    )
    tbs = world_size * batch_size
    return tbs / time


def train_loop(
    model_and_loss,
    optimizer,
    lr_scheduler,
    train_loader,
    val_loader,
    fp16,
    logger,
    should_backup_checkpoint,
    use_amp=False,
    batch_size_multiplier=1,
    best_prec1=0,
    start_epoch=0,
    end_epoch=0,
    prof=-1,
    skip_training=False,
    skip_validation=False,
    save_checkpoints=True,
    checkpoint_dir="./",
    checkpoint_filename="checkpoint.pth.tar",
):
    prec1 = -1

    print(f"RUNNING EPOCHS FROM {start_epoch} TO {end_epoch}")
    for epoch in range(start_epoch, end_epoch):
        if logger is not None:
            logger.start_epoch()
        if not skip_training:
            train(
                train_loader,
                model_and_loss,
                optimizer,
                lr_scheduler,
                fp16,
                logger,
                epoch,
                use_amp=use_amp,
                prof=prof,
                register_metrics=epoch == start_epoch,
                batch_size_multiplier=batch_size_multiplier,
            )

        if not skip_validation:
            prec1, nimg = validate(
                val_loader,
                model_and_loss,
                fp16,
                logger,
                epoch,
                prof=prof,
                register_metrics=epoch == start_epoch,
            )
        if logger is not None:
            logger.end_epoch()

        if save_checkpoints and (
            not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
        ):
            if not skip_validation:
                is_best = logger.metrics["val.top1"]["meter"].get_epoch() > best_prec1
                best_prec1 = max(
                    logger.metrics["val.top1"]["meter"].get_epoch(), best_prec1
                )
            else:
                is_best = False
                best_prec1 = 0

            if should_backup_checkpoint(epoch):
                backup_filename = "checkpoint-{}.pth.tar".format(epoch + 1)
            else:
                backup_filename = None
            utils.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": model_and_loss.arch,
                    "state_dict": model_and_loss.model.state_dict(),
                    "best_prec1": best_prec1,
                    "optimizer": optimizer.state_dict(),
                },
                is_best,
                checkpoint_dir=checkpoint_dir,
                backup_filename=backup_filename,
                filename=checkpoint_filename,
            )


# }}}