2019-12-15 05:13:59 +01:00
|
|
|
# Copyright (c) 2018-2019, NVIDIA CORPORATION
|
|
|
|
# Copyright (c) 2017- Facebook, Inc
|
|
|
|
#
|
|
|
|
# All rights reserved.
|
|
|
|
#
|
|
|
|
# Redistribution and use in source and binary forms, with or without
|
|
|
|
# modification, are permitted provided that the following conditions are met:
|
|
|
|
#
|
|
|
|
# * Redistributions of source code must retain the above copyright notice, this
|
|
|
|
# list of conditions and the following disclaimer.
|
|
|
|
#
|
|
|
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
|
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
|
|
# and/or other materials provided with the distribution.
|
|
|
|
#
|
|
|
|
# * Neither the name of the copyright holder nor the names of its
|
|
|
|
# contributors may be used to endorse or promote products derived from
|
|
|
|
# this software without specific prior written permission.
|
|
|
|
#
|
|
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
2021-04-09 23:12:57 +02:00
|
|
|
import math
|
2019-05-27 15:24:14 +02:00
|
|
|
import time
|
2021-04-09 23:12:57 +02:00
|
|
|
|
2019-05-27 15:24:14 +02:00
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from torch.autograd import Variable
|
|
|
|
from . import logger as log
|
2021-04-09 23:12:57 +02:00
|
|
|
from . import models
|
2019-05-27 15:24:14 +02:00
|
|
|
from . import utils
|
2019-12-20 14:46:11 +01:00
|
|
|
import dllogger
|
2020-06-27 09:32:20 +02:00
|
|
|
|
2021-04-09 23:12:57 +02:00
|
|
|
from .optimizers import get_sgd_optimizer, get_rmsprop_optimizer
|
|
|
|
from .models.common import EMA
|
2021-03-04 16:14:35 +01:00
|
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
|
from torch.cuda.amp import autocast
|
2019-12-20 14:46:11 +01:00
|
|
|
|
2020-06-27 09:32:20 +02:00
|
|
|
ACC_METADATA = {"unit": "%", "format": ":.2f"}
|
|
|
|
IPS_METADATA = {"unit": "img/s", "format": ":.2f"}
|
|
|
|
TIME_METADATA = {"unit": "s", "format": ":.5f"}
|
|
|
|
LOSS_METADATA = {"format": ":.5f"}
|
2019-05-27 15:24:14 +02:00
|
|
|
|
|
|
|
|
|
|
|
class ModelAndLoss(nn.Module):
    """Bundle a model and its loss criterion into a single ``nn.Module``.

    The forward pass runs the model and immediately evaluates the loss, so
    training-step code only deals with one callable, and DDP can wrap just
    the inner model via :meth:`distributed`.
    """

    def __init__(
        self,
        model,
        loss,
        cuda=True,
        memory_format=torch.contiguous_format,
    ):
        """
        Args:
            model: an instantiated ``nn.Module``.
            loss: a zero-argument callable (typically a criterion class such
                as ``nn.CrossEntropyLoss``) used to construct the criterion.
            cuda: when True, move model and criterion to the GPU.
            memory_format: memory layout applied to the model when ``cuda``
                is True (e.g. ``torch.channels_last``).
        """
        super().__init__()

        if cuda:
            model = model.cuda().to(memory_format=memory_format)

        # define loss function (criterion) and optimizer
        criterion = loss()

        if cuda:
            criterion = criterion.cuda()

        self.model = model
        self.loss = criterion

    def forward(self, data, target):
        """Run the model on ``data`` and return ``(loss, output)``."""
        output = self.model(data)
        loss = self.loss(output, target)

        return loss, output

    def distributed(self, gpu_id):
        """Wrap the inner model in DistributedDataParallel pinned to ``gpu_id``."""
        self.model = DDP(self.model, device_ids=[gpu_id], output_device=gpu_id)

    def load_model_state(self, state):
        """Load ``state`` into the wrapped model; ``None`` is a no-op.

        Fixed: the original used the non-idiomatic ``not state is None``.
        """
        if state is not None:
            self.model.load_state_dict(state)
|
|
|
|
|
|
|
|
|
2021-04-09 23:12:57 +02:00
|
|
|
def get_optimizer(parameters, lr, args, state=None):
    """Build the optimizer selected by ``args.optimizer``.

    Args:
        parameters: iterable of model parameters (or parameter groups).
        lr: initial learning rate.
        args: parsed CLI namespace; reads ``optimizer``, ``momentum``,
            ``weight_decay``, ``nesterov``, ``bn_weight_decay`` and, for
            rmsprop, ``rmsprop_alpha`` and ``rmsprop_eps``.
        state: optional optimizer state dict to restore (checkpoint resume).

    Returns:
        The configured optimizer.

    Raises:
        ValueError: if ``args.optimizer`` names an unknown optimizer
            (previously this fell through to an ``UnboundLocalError``).
    """
    if args.optimizer == 'sgd':
        optimizer = get_sgd_optimizer(parameters, lr, momentum=args.momentum,
                                      weight_decay=args.weight_decay, nesterov=args.nesterov,
                                      bn_weight_decay=args.bn_weight_decay)
    elif args.optimizer == 'rmsprop':
        optimizer = get_rmsprop_optimizer(parameters, lr, alpha=args.rmsprop_alpha, momentum=args.momentum,
                                          weight_decay=args.weight_decay,
                                          eps=args.rmsprop_eps,
                                          bn_weight_decay=args.bn_weight_decay)
    else:
        raise ValueError(f"Unknown optimizer: {args.optimizer}")

    # Fixed idiom: was `if not state is None:`.
    if state is not None:
        optimizer.load_state_dict(state)

    return optimizer
|
|
|
|
|
|
|
|
|
|
|
|
def lr_policy(lr_fn, logger=None):
    """Wrap ``lr_fn`` into a scheduler callable.

    The returned callable ``(optimizer, iteration, epoch)`` evaluates
    ``lr_fn``, optionally logs the value, and writes the new learning rate
    into every parameter group of the optimizer.
    """
    if logger is not None:
        logger.register_metric(
            "lr", log.LR_METER(), verbosity=dllogger.Verbosity.VERBOSE
        )

    def _alr(optimizer, iteration, epoch):
        new_lr = lr_fn(iteration, epoch)

        if logger is not None:
            logger.log_metric("lr", new_lr)

        for group in optimizer.param_groups:
            group["lr"] = new_lr

    return _alr
|
|
|
|
|
|
|
|
|
|
|
|
def lr_step_policy(base_lr, steps, decay_factor, warmup_length, logger=None):
    """Step-decay schedule: linear warmup, then multiply ``base_lr`` by
    ``decay_factor`` once for every milestone in ``steps`` already passed."""

    def _lr_fn(iteration, epoch):
        # Linear warmup for the first `warmup_length` epochs.
        if epoch < warmup_length:
            return base_lr * (epoch + 1) / warmup_length
        lr = base_lr
        for milestone in steps:
            if epoch >= milestone:
                lr *= decay_factor
        return lr

    return lr_policy(_lr_fn, logger=logger)
|
|
|
|
|
|
|
|
|
|
|
|
def lr_linear_policy(base_lr, warmup_length, epochs, logger=None):
    """Linear decay from ``base_lr`` to 0 over the post-warmup epochs,
    preceded by a linear warmup."""

    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            # Warmup: ramp from base_lr/warmup_length up to base_lr.
            return base_lr * (epoch + 1) / warmup_length
        e = epoch - warmup_length
        es = epochs - warmup_length
        return base_lr * (1 - (e / es))

    return lr_policy(_lr_fn, logger=logger)
|
|
|
|
|
|
|
|
|
2021-04-09 23:12:57 +02:00
|
|
|
def lr_cosine_policy(base_lr, warmup_length, epochs, end_lr=0, logger=None):
    """Cosine annealing from ``base_lr`` down to ``end_lr`` after a linear
    warmup of ``warmup_length`` epochs."""

    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            # Warmup: ramp from base_lr/warmup_length up to base_lr.
            return base_lr * (epoch + 1) / warmup_length
        e = epoch - warmup_length
        es = epochs - warmup_length
        return end_lr + (0.5 * (1 + np.cos(np.pi * e / es)) * (base_lr - end_lr))

    return lr_policy(_lr_fn, logger=logger)
|
|
|
|
|
|
|
|
|
2020-06-27 09:32:20 +02:00
|
|
|
def lr_exponential_policy(
    base_lr, warmup_length, epochs, final_multiplier=0.001, decay_factor=None, decay_step=1, logger=None
):
    """Exponential lr policy. Setting decay factor parameter overrides final_multiplier"""
    es = epochs - warmup_length

    # Either use the explicit per-step decay, or derive one so the lr ends
    # at base_lr * final_multiplier after all post-warmup decay steps.
    epoch_decay = (
        decay_factor
        if decay_factor is not None
        else np.power(2, np.log2(final_multiplier) / math.floor(es / decay_step))
    )

    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            return base_lr * (epoch + 1) / warmup_length
        e = epoch - warmup_length
        return base_lr * (epoch_decay ** math.floor(e / decay_step))

    return lr_policy(_lr_fn, logger=logger)
|
|
|
|
|
|
|
|
|
2020-06-27 09:32:20 +02:00
|
|
|
def get_train_step(
    model_and_loss, optimizer, scaler, use_amp=False, batch_size_multiplier=1
):
    """Build the closure that performs one training step.

    Args:
        model_and_loss: ``ModelAndLoss`` returning ``(loss, output)``.
        scaler: ``torch.cuda.amp.GradScaler`` handling loss scaling.
        use_amp: run the forward pass under autocast when True.
        batch_size_multiplier: number of micro-batches accumulated per
            optimizer step; the loss is divided by it so gradients average
            correctly.

    Returns:
        ``_step(input, target, optimizer_step=True)`` returning the
        (optionally all-reduced) detached loss.
    """

    def _step(input, target, optimizer_step=True):
        # NOTE: the original wrapped input/target in torch.autograd.Variable;
        # Variable has been a no-op alias of Tensor since PyTorch 0.4, so the
        # wrappers were removed.
        with autocast(enabled=use_amp):
            loss, output = model_and_loss(input, target)
            # Average over accumulated micro-batches.
            loss /= batch_size_multiplier
            if torch.distributed.is_initialized():
                # Average the loss across ranks for reporting.
                reduced_loss = utils.reduce_tensor(loss.data)
            else:
                reduced_loss = loss.data

        # Scaled backward to avoid fp16 gradient underflow.
        scaler.scale(loss).backward()

        if optimizer_step:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        torch.cuda.synchronize()

        return reduced_loss

    return _step
|
|
|
|
|
|
|
|
|
2020-06-27 09:32:20 +02:00
|
|
|
def train(
    train_loader,
    model_and_loss,
    optimizer,
    scaler,
    lr_scheduler,
    logger,
    epoch,
    steps_per_epoch,
    timeout_handler,
    ema=None,
    use_amp=False,
    prof=-1,
    batch_size_multiplier=1,
    register_metrics=True,
):
    """Run one training epoch.

    Iterates ``train_loader``, applies the lr schedule per iteration, runs
    the (possibly gradient-accumulated) train step, optionally updates an
    EMA of the weights, and logs throughput/timing metrics.

    Args:
        lr_scheduler: callable ``(optimizer, iteration, epoch)`` (see
            ``lr_policy``).
        ema: optional callable ``(model, step)`` updating the weight EMA.
        prof: if > 0, stop after that many iterations (profiling mode).
        batch_size_multiplier: gradient-accumulation factor; the optimizer
            only steps every ``batch_size_multiplier`` iterations.
        register_metrics: register the train.* metrics on the logger
            (should be done only on the first epoch).

    Returns:
        True when the run was interrupted by ``timeout_handler``,
        False otherwise.
    """
    interrupted = False
    if register_metrics and logger is not None:
        logger.register_metric(
            "train.loss",
            log.LOSS_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=LOSS_METADATA,
        )
        logger.register_metric(
            "train.compute_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            "train.total_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            "train.data_time",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "train.compute_time",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )

    step = get_train_step(
        model_and_loss,
        optimizer,
        scaler=scaler,
        use_amp=use_amp,
        batch_size_multiplier=batch_size_multiplier,
    )

    model_and_loss.train()
    end = time.time()

    optimizer.zero_grad()

    data_iter = enumerate(train_loader)
    if logger is not None:
        data_iter = logger.iteration_generator_wrapper(data_iter, mode='train')

    for i, (input, target) in data_iter:
        bs = input.size(0)
        # Per-iteration lr update (scheduler may use iteration and epoch).
        lr_scheduler(optimizer, i, epoch)
        data_time = time.time() - end

        # Step the optimizer only once per accumulation window.
        optimizer_step = ((i + 1) % batch_size_multiplier) == 0
        loss = step(input, target, optimizer_step=optimizer_step)
        if ema is not None:
            # Global step index feeds the EMA decay schedule.
            ema(model_and_loss, epoch*steps_per_epoch+i)

        it_time = time.time() - end

        if logger is not None:
            logger.log_metric("train.loss", loss.item(), bs)
            logger.log_metric("train.compute_ips", utils.calc_ips(bs, it_time - data_time))
            logger.log_metric("train.total_ips", utils.calc_ips(bs, it_time))
            logger.log_metric("train.data_time", data_time)
            logger.log_metric("train.compute_time", it_time - data_time)

        end = time.time()
        if prof > 0 and (i + 1 >= prof):
            # Profiling cut-off; sleep lets async profilers drain.
            time.sleep(5)
            break
        # Check the timeout flag only every 20 iterations to keep overhead low.
        if ((i+1) % 20 == 0) and timeout_handler.interrupted:
            time.sleep(5)
            interrupted = True
            break

    return interrupted
|
|
|
|
|
2019-05-27 15:24:14 +02:00
|
|
|
|
2021-03-04 16:14:35 +01:00
|
|
|
def get_val_step(model_and_loss, use_amp=False):
    """Build the closure that performs one validation step.

    Returns:
        ``_step(input, target)`` returning
        ``(reduced_loss, prec1, prec5)`` where precisions are top-1/top-5
        accuracies, all-reduced across ranks when distributed.
    """

    def _step(input, target):
        # NOTE: the original wrapped input/target in torch.autograd.Variable;
        # Variable has been a no-op alias of Tensor since PyTorch 0.4, so the
        # wrappers were removed.
        with torch.no_grad(), autocast(enabled=use_amp):
            loss, output = model_and_loss(input, target)

            prec1, prec5 = utils.accuracy(output.data, target, topk=(1, 5))

            if torch.distributed.is_initialized():
                # Average metrics across ranks for consistent reporting.
                reduced_loss = utils.reduce_tensor(loss.data)
                prec1 = utils.reduce_tensor(prec1)
                prec5 = utils.reduce_tensor(prec5)
            else:
                reduced_loss = loss.data

        torch.cuda.synchronize()

        return reduced_loss, prec1, prec5

    return _step
|
|
|
|
|
|
|
|
|
2020-06-27 09:32:20 +02:00
|
|
|
def validate(
    val_loader,
    model_and_loss,
    logger,
    epoch,
    use_amp=False,
    prof=-1,
    register_metrics=True,
    prefix="val",
):
    """Run one validation pass over ``val_loader``.

    Args:
        prof: if > 0, stop after that many iterations (profiling mode).
        register_metrics: register the ``{prefix}.*`` metrics on the logger
            (should be done only on the first epoch).
        prefix: metric-name prefix; ``"val_ema"`` is used for the EMA model.

    Returns:
        ``top1.get_val()`` — the aggregated top-1 accuracy tracked by
        ``log.AverageMeter`` (callers unpack it as ``prec1, nimg``).
    """
    if register_metrics and logger is not None:
        logger.register_metric(
            f"{prefix}.top1",
            log.ACC_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=ACC_METADATA,
        )
        logger.register_metric(
            f"{prefix}.top5",
            log.ACC_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=ACC_METADATA,
        )
        logger.register_metric(
            f"{prefix}.loss",
            log.LOSS_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=LOSS_METADATA,
        )
        logger.register_metric(
            f"{prefix}.compute_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            f"{prefix}.total_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            f"{prefix}.data_time",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            f"{prefix}.compute_latency",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            f"{prefix}.compute_latency_at100",
            log.LAT_100(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            f"{prefix}.compute_latency_at99",
            log.LAT_99(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            f"{prefix}.compute_latency_at95",
            log.LAT_95(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )

    step = get_val_step(model_and_loss, use_amp=use_amp)

    top1 = log.AverageMeter()
    # switch to evaluate mode
    model_and_loss.eval()

    end = time.time()

    data_iter = enumerate(val_loader)
    if not logger is None:
        data_iter = logger.iteration_generator_wrapper(data_iter, mode='val')

    for i, (input, target) in data_iter:
        bs = input.size(0)
        data_time = time.time() - end

        loss, prec1, prec5 = step(input, target)

        it_time = time.time() - end

        # Aggregate top-1 locally; this is the function's return value.
        top1.record(prec1.item(), bs)
        if logger is not None:
            logger.log_metric(f"{prefix}.top1", prec1.item(), bs)
            logger.log_metric(f"{prefix}.top5", prec5.item(), bs)
            logger.log_metric(f"{prefix}.loss", loss.item(), bs)
            logger.log_metric(f"{prefix}.compute_ips", utils.calc_ips(bs, it_time - data_time))
            logger.log_metric(f"{prefix}.total_ips", utils.calc_ips(bs, it_time))
            logger.log_metric(f"{prefix}.data_time", data_time)
            # The same per-iteration latency feeds every percentile meter;
            # the meters themselves compute the percentiles.
            logger.log_metric(f"{prefix}.compute_latency", it_time - data_time)
            logger.log_metric(f"{prefix}.compute_latency_at95", it_time - data_time)
            logger.log_metric(f"{prefix}.compute_latency_at99", it_time - data_time)
            logger.log_metric(f"{prefix}.compute_latency_at100", it_time - data_time)

        end = time.time()
        if (prof > 0) and (i + 1 >= prof):
            # Profiling cut-off; sleep lets async profilers drain.
            time.sleep(5)
            break

    return top1.get_val()
|
|
|
|
|
2019-12-20 14:46:11 +01:00
|
|
|
|
2019-05-27 15:24:14 +02:00
|
|
|
# Train loop {{{
|
2019-12-20 14:46:11 +01:00
|
|
|
|
|
|
|
|
2020-06-27 09:32:20 +02:00
|
|
|
def train_loop(
    model_and_loss,
    optimizer,
    scaler,
    lr_scheduler,
    train_loader,
    val_loader,
    logger,
    should_backup_checkpoint,
    steps_per_epoch,
    ema=None,
    model_ema=None,
    use_amp=False,
    batch_size_multiplier=1,
    best_prec1=0,
    start_epoch=0,
    end_epoch=0,
    early_stopping_patience=-1,
    prof=-1,
    skip_training=False,
    skip_validation=False,
    save_checkpoints=True,
    checkpoint_dir="./",
    checkpoint_filename="checkpoint.pth.tar",
):
    """Main train/validate loop over epochs with checkpointing.

    Per epoch: train (unless ``skip_training``), validate (unless
    ``skip_validation``; when an EMA model is supplied it is validated too,
    under the ``val_ema`` prefix), save a checkpoint on rank 0, and apply
    optional early stopping. A SIGTERM-style timeout (via
    ``utils.TimeoutHandler``) ends the loop after finishing the current
    epoch's bookkeeping.

    Args:
        should_backup_checkpoint: callable ``(epoch) -> bool`` deciding
            whether to keep a per-epoch backup copy of the checkpoint.
        ema: EMA-update callable passed to ``train``.
        model_ema: module that receives the EMA weights for validation.
        early_stopping_patience: stop after this many epochs without
            improvement; <= 0 disables early stopping.
    """
    prec1 = -1
    # EMA validation needs both the update callable and a holder module.
    use_ema = (model_ema is not None) and (ema is not None)

    if early_stopping_patience > 0:
        epochs_since_improvement = 0

    # Backup checkpoints reuse any custom prefix of the checkpoint filename.
    backup_prefix = checkpoint_filename[:-len("checkpoint.pth.tar")] if \
        checkpoint_filename.endswith("checkpoint.pth.tar") else ""

    print(f"RUNNING EPOCHS FROM {start_epoch} TO {end_epoch}")
    with utils.TimeoutHandler() as timeout_handler:
        interrupted = False
        for epoch in range(start_epoch, end_epoch):
            if logger is not None:
                logger.start_epoch()
            if not skip_training:
                interrupted = train(
                    train_loader,
                    model_and_loss,
                    optimizer,
                    scaler,
                    lr_scheduler,
                    logger,
                    epoch,
                    steps_per_epoch,
                    timeout_handler,
                    ema=ema,
                    use_amp=use_amp,
                    prof=prof,
                    register_metrics=epoch == start_epoch,
                    batch_size_multiplier=batch_size_multiplier,
                )

            if not skip_validation:
                prec1, nimg = validate(
                    val_loader,
                    model_and_loss,
                    logger,
                    epoch,
                    use_amp=use_amp,
                    prof=prof,
                    register_metrics=epoch == start_epoch,
                )
                if use_ema:
                    # Copy EMA weights (stripping any DDP 'module.' prefix)
                    # into the holder model and validate it separately.
                    model_ema.load_state_dict({k.replace('module.', ''): v for k, v in ema.state_dict().items()})
                    prec1, nimg = validate(
                        val_loader,
                        model_ema,
                        logger,
                        epoch,
                        prof=prof,
                        register_metrics=epoch == start_epoch,
                        prefix='val_ema'
                    )

                # NOTE(review): when use_ema is set, best_prec1/is_best track
                # the EMA model's accuracy (prec1 was overwritten above).
                if prec1 > best_prec1:
                    is_best = True
                    best_prec1 = prec1
                else:
                    is_best = False
            else:
                is_best = True
                best_prec1 = 0

            if logger is not None:
                logger.end_epoch()

            # Only rank 0 writes checkpoints in distributed runs.
            if save_checkpoints and (
                not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
            ):
                if should_backup_checkpoint(epoch):
                    backup_filename = "{}checkpoint-{}.pth.tar".format(backup_prefix, epoch + 1)
                else:
                    backup_filename = None
                checkpoint_state = {
                    "epoch": epoch + 1,
                    "state_dict": model_and_loss.model.state_dict(),
                    "best_prec1": best_prec1,
                    "optimizer": optimizer.state_dict(),
                }
                if use_ema:
                    checkpoint_state["state_dict_ema"] = ema.state_dict()

                utils.save_checkpoint(
                    checkpoint_state,
                    is_best,
                    checkpoint_dir=checkpoint_dir,
                    backup_filename=backup_filename,
                    filename=checkpoint_filename,
                )
            if early_stopping_patience > 0:
                if not is_best:
                    epochs_since_improvement += 1
                else:
                    epochs_since_improvement = 0
                if epochs_since_improvement >= early_stopping_patience:
                    break
            # Timeout raised during train(): exit after checkpointing.
            if interrupted:
                break
|
2019-05-27 15:24:14 +02:00
|
|
|
# }}}
|