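# ImageNet classification training script (PyTorch + torchvision models).
# Supports optional fp16 training via NVIDIA apex (FP16_Optimizer, network_to_half),
# multi-GPU/distributed training through apex DistributedDataParallel, checkpoint
# resume/backup, and standalone training/inference throughput benchmarking modes.
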
import argparse
import os
import shutil
import time
import random

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import numpy as np

try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")

def add_parser_arguments(parser):
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))

    parser.add_argument('data', metavar='DIR',
                        help='path to dataset')
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                        choices=model_names,
                        help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
    parser.add_argument('-j', '--workers', default=5, type=int, metavar='N',
                        help='number of data loading workers (default: 5)')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=256, type=int,
                        metavar='N', help='mini-batch size (default: 256)')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--warmup', default=0, type=int,
                        metavar='E', help='number of warmup epochs')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--print-freq', '-p', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')
    parser.add_argument('--pretrained-weights', default='', type=str, metavar='PATH',
                        help='file with weights')

    parser.add_argument('--fp16', action='store_true',
                        help='Run model fp16 mode.')
    parser.add_argument('--static-loss-scale', type=float, default=1,
                        help='Static loss scale, positive power of 2 values can improve fp16 convergence.')
    parser.add_argument('--dynamic-loss-scale', action='store_true',
                        help='Use dynamic loss scaling. If supplied, this argument supersedes ' +
                        '--static-loss-scale.')
    parser.add_argument('--prof', dest='prof', action='store_true',
                        help='Only run 10 iterations for profiling.')

    parser.add_argument('--benchmark-training', dest='trainbench', action='store_true',
                        help='Run benchmarking of training')
    parser.add_argument('--benchmark-inference', dest='inferbench', action='store_true',
                        help='Run benchmarking of inference')
    parser.add_argument('--bench-iterations', type=int, default=20, metavar='N',
                        help='Run N iterations while benchmarking (ignored when training and validation)')
    parser.add_argument('--bench-warmup', type=int, default=20, metavar='N',
                        help='Number of warmup iterations for benchmarking')

    parser.add_argument('--dist-url', default='tcp://localhost:23456', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')

    parser.add_argument('--world-size', default=1, type=int,
                        help='Number of GPUs to use. Can either be manually set ' +
                        'or automatically set by using \'python -m multiproc\'.')
    parser.add_argument('--rank', default=0, type=int,
                        help='Used for multi-process training. Can either be manually set ' +
                        'or automatically set by using \'python -m multiproc\'.')

    parser.add_argument('--seed', default=None, type=int,
                        help='random seed used for np and pytorch')

    parser.add_argument('--gather-checkpoints', action='store_true',
                        help='Gather checkpoints throughout the training')


def main():
    if args.trainbench or args.inferbench:
        logger = BenchLogger
    else:
        logger = PrintLogger

    train_net(args, logger)


def train_net(args, logger_cls):
    exp_start_time = time.time()
    global best_prec1
    best_prec1 = 0

    args.distributed = args.world_size > 1
    args.gpu = 0
    if args.distributed:
        args.gpu = args.rank % torch.cuda.device_count()

    if args.distributed:
        torch.cuda.set_device(args.gpu)
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    if args.seed is not None:
        print("Using seed = {}".format(args.seed))
        torch.manual_seed(args.seed + args.rank)
        torch.cuda.manual_seed(args.seed + args.rank)
        np.random.seed(seed=args.seed + args.rank)
        random.seed(args.seed + args.rank)

        def _worker_init_fn(id):
            np.random.seed(seed=args.seed + args.rank + id)
            random.seed(args.seed + args.rank + id)
    else:
        def _worker_init_fn(id):
            pass

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if args.static_loss_scale != 1.0:
        if not args.fp16:
            print("Warning: if --fp16 is not used, static_loss_scale will be ignored.")

    pretrained_weights = None
    if args.pretrained_weights:
        if os.path.isfile(args.pretrained_weights):
            print("=> loading pretrained weights from '{}'".format(args.pretrained_weights))
            pretrained_weights = torch.load(args.pretrained_weights)
        else:
            print("=> no pretrained weights found at '{}'".format(args.pretrained_weights))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model_state = checkpoint['state_dict']
            optimizer_state = checkpoint['optimizer']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            model_state = None
            optimizer_state = None
    else:
        model_state = None
        optimizer_state = None

    model_and_loss = ModelAndLoss(args.arch, nn.CrossEntropyLoss,
                                  args.pretrained, pretrained_weights=pretrained_weights, state=model_state,
                                  cuda=True, fp16=args.fp16, distributed=args.distributed)

    # Create data loaders and optimizers as needed

    if not (args.evaluate or args.inferbench):
        optimizer = get_optimizer(model_and_loss.model.parameters(),
                                  args.fp16,
                                  args.lr, args.momentum, args.weight_decay,
                                  state=optimizer_state,
                                  static_loss_scale=args.static_loss_scale,
                                  dynamic_loss_scale=args.dynamic_loss_scale)

        train_loader = get_train_loader(args.data, args.batch_size, workers=args.workers, _worker_init_fn=_worker_init_fn)
        train_loader_len = len(train_loader)
    else:
        train_loader_len = 0

    if not args.trainbench:
        val_loader = get_val_loader(args.data, args.batch_size, workers=args.workers, _worker_init_fn=_worker_init_fn)
        val_loader_len = len(val_loader)
    else:
        val_loader_len = 0

    if args.evaluate:
        logger = logger_cls(train_loader_len, val_loader_len, args)
        validate(val_loader, model_and_loss, args.fp16, logger, 0)
        return

    if args.trainbench:
        model_and_loss.model.train()
        logger = logger_cls("Train", args.world_size * args.batch_size, args.bench_warmup)
        bench(get_train_step(model_and_loss, optimizer, args.fp16), train_loader,
              args.bench_warmup, args.bench_iterations, args.fp16, logger, epoch_warmup=True)
        return

    if args.inferbench:
        model_and_loss.model.eval()
        logger = logger_cls("Inference", args.world_size * args.batch_size, args.bench_warmup)
        bench(get_val_step(model_and_loss), val_loader,
              args.bench_warmup, args.bench_iterations, args.fp16, logger, epoch_warmup=False)
        return

    logger = logger_cls(train_loader_len, val_loader_len, args)
    train_loop(model_and_loss, optimizer, adjust_learning_rate(args), train_loader, val_loader, args.epochs,
               args.fp16, logger, should_backup_checkpoint(args),
               start_epoch=args.start_epoch, best_prec1=best_prec1, prof=args.prof)

    exp_duration = time.time() - exp_start_time
    logger.experiment_timer(exp_duration)
    logger.end_callback()
    print("Experiment ended")


# Get Optimizer {{{
def get_optimizer(parameters, fp16, lr, momentum, weight_decay, state=None, static_loss_scale=1., dynamic_loss_scale=False):
    optimizer = torch.optim.SGD(parameters, lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
    if fp16:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=static_loss_scale,
                                   dynamic_loss_scale=dynamic_loss_scale)

    if state is not None:
        optimizer.load_state_dict(state)

    return optimizer
# }}}

# ModelAndLoss {{{
class ModelAndLoss(nn.Module):
    def __init__(self, arch, loss, pretrained, pretrained_weights=None, state=None, cuda=True, fp16=False, distributed=False):
        super(ModelAndLoss, self).__init__()
        self.arch = arch
        if pretrained:
            if pretrained_weights is None:
                print("=> using pre-trained model from torchvision '{}'".format(arch))
                model = models.__dict__[arch](pretrained=True)
            else:
                print("=> using pre-trained model from a file '{}'".format(arch))
                model = models.__dict__[arch]()
                model.load_state_dict(pretrained_weights)
        else:
            print("=> creating model '{}'".format(arch))
            model = models.__dict__[arch]()

        if cuda:
            model = model.cuda()
        if fp16:
            model = network_to_half(model)
        if distributed:
            # shared_param turns off bucketing in DDP; for lower latency runs this can improve perf
            model = DDP(model, shared_param=True)

        if state is not None:
            model.load_state_dict(state)

        # define loss function (criterion) and optimizer
        criterion = loss()

        if cuda:
            criterion = criterion.cuda()

        self.model = model
        self.loss = criterion

    def forward(self, data, target):
        output = self.model(data)
        loss = self.loss(output, target)

        return loss, output
# }}}

# Train loop {{{
def train_loop(model_and_loss, optimizer, lr_scheduler, train_loader, val_loader, epochs, fp16, logger,
               should_backup_checkpoint,
               best_prec1=0, start_epoch=0, prof=False):

    for epoch in range(start_epoch, epochs):
        if torch.distributed.is_initialized():
            train_loader.sampler.set_epoch(epoch)

        lr_scheduler(optimizer, epoch)

        train(train_loader, model_and_loss, optimizer, fp16, logger, epoch, prof=prof)

        prec1 = validate(val_loader, model_and_loss, fp16, logger, epoch, prof=prof)

        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            if should_backup_checkpoint(epoch):
                backup_filename = 'checkpoint-{}.pth.tar'.format(epoch + 1)
            else:
                backup_filename = None
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': model_and_loss.arch,
                'state_dict': model_and_loss.model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, backup_filename=backup_filename)
# }}}

# Data Loading functions {{{
def fast_collate(batch):
    imgs = [img[0] for img in batch]
    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        tens = torch.from_numpy(nump_array)
        if nump_array.ndim < 3:
            nump_array = np.expand_dims(nump_array, axis=-1)
        nump_array = np.rollaxis(nump_array, 2)

        tensor[i] += torch.from_numpy(nump_array)

    return tensor, targets
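
# Note: fast_collate packs the still-uint8 PIL images into a pre-allocated NCHW uint8
# tensor; conversion to float and mean/std normalization are deferred to
# prefetched_loader below. The transforms.ToTensor() call in get_train_loader is left
# commented out for the same reason (it is marked "Too slow" there).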


def prefetched_loader(loader, fp16):
    mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1, 3, 1, 1)
    std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1, 3, 1, 1)
    if fp16:
        mean = mean.half()
        std = std.half()

    stream = torch.cuda.Stream()
    first = True

    for next_input, next_target in loader:
        with torch.cuda.stream(stream):
            # non_blocking replaces the deprecated `async` keyword argument
            # (`async` became a reserved word in Python 3.7)
            next_input = next_input.cuda(non_blocking=True)
            next_target = next_target.cuda(non_blocking=True)
            if fp16:
                next_input = next_input.half()
            else:
                next_input = next_input.float()
            next_input = next_input.sub_(mean).div_(std)

        if not first:
            yield input, target
        else:
            first = False

        torch.cuda.current_stream().wait_stream(stream)
        input = next_input
        target = next_target

    yield input, target
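
# Note: prefetched_loader double-buffers the input pipeline. The next batch is copied to
# the GPU and normalized on a side CUDA stream while the current batch is being consumed;
# the mean/std constants are scaled by 255 because fast_collate yields uint8 tensors.
# `input`/`target` intentionally lag one iteration behind `next_input`/`next_target`.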


def get_train_loader(data_path, batch_size, workers=5, _worker_init_fn=None):
    traindir = os.path.join(data_path, 'train')
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            # transforms.ToTensor(), Too slow
            # normalize,
        ]))

    if torch.distributed.is_initialized():
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),
        num_workers=workers, worker_init_fn=_worker_init_fn, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate, drop_last=True)

    return train_loader


def get_val_loader(data_path, batch_size, workers=5, _worker_init_fn=None):
    valdir = os.path.join(data_path, 'val')

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ])),
        batch_size=batch_size, shuffle=False,
        num_workers=workers, worker_init_fn=_worker_init_fn, pin_memory=True,
        collate_fn=fast_collate)

    return val_loader
# }}}

# Train val bench {{{
def get_train_step(model_and_loss, optimizer, fp16):
    def _step(input, target):
        input_var = Variable(input)
        target_var = Variable(target)
        loss, output = model_and_loss(input_var, target_var)
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if torch.distributed.is_initialized():
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        optimizer.zero_grad()
        if fp16:
            optimizer.backward(loss)
        else:
            loss.backward()
        optimizer.step()

        torch.cuda.synchronize()

        return reduced_loss, prec1, prec5

    return _step
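
# Note: in fp16 mode `optimizer` is apex's FP16_Optimizer, so optimizer.backward(loss)
# applies the configured static or dynamic loss scaling before backpropagating and
# updates fp32 master weights; in fp32 mode the plain loss.backward() path is used.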


def bench(step, train_loader, warmup, iterations, fp16, logger, epoch_warmup=False):
    step = timed_function(step)

    if epoch_warmup:
        print("Running first epoch for warmup, please wait")

        for (input, target), dt in timed_generator(prefetched_loader(train_loader, fp16)):
            _, bt = step(input, target)

    print("Running benchmarked epoch")

    for (input, target), dt in timed_generator(prefetched_loader(train_loader, fp16)):
        _, bt = step(input, target)
        logger.iter_callback({'data_time': dt, 'batch_time': bt})

        if logger.i >= warmup + iterations:
            break

    logger.end_callback()


def train(train_loader, model_and_loss, optimizer, fp16, logger, epoch, prof=False):

    step = get_train_step(model_and_loss, optimizer, fp16)

    model_and_loss.model.train()
    end = time.time()

    for i, (input, target) in enumerate(prefetched_loader(train_loader, fp16)):
        data_time = time.time() - end

        if prof:
            if i > 10:
                break

        loss, prec1, prec5 = step(input, target)

        logger.train_iter_callback(epoch, i,
                                   {'size': input.size(0),
                                    'top1': to_python_float(prec1),
                                    'top5': to_python_float(prec5),
                                    'loss': to_python_float(loss),
                                    'time': time.time() - end,
                                    'data': data_time})

        end = time.time()

    logger.train_epoch_callback(epoch)


def get_val_step(model_and_loss):
    def _step(input, target):
        input_var = Variable(input)
        target_var = Variable(target)

        with torch.no_grad():
            loss, output = model_and_loss(input_var, target_var)

        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if torch.distributed.is_initialized():
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        torch.cuda.synchronize()

        return reduced_loss, prec1, prec5

    return _step


def validate(val_loader, model_and_loss, fp16, logger, epoch, prof=False):

    step = get_val_step(model_and_loss)

    top1 = AverageMeter()
    # switch to evaluate mode
    model_and_loss.model.eval()

    end = time.time()

    for i, (input, target) in enumerate(prefetched_loader(val_loader, fp16)):
        data_time = time.time() - end
        if prof:
            if i > 10:
                break

        loss, prec1, prec5 = step(input, target)

        top1.update(to_python_float(prec1), input.size(0))

        logger.val_iter_callback(epoch, i,
                                 {'size': input.size(0),
                                  'top1': to_python_float(prec1),
                                  'top5': to_python_float(prec5),
                                  'loss': to_python_float(loss),
                                  'time': time.time() - end,
                                  'data': data_time})

        end = time.time()

    logger.val_epoch_callback(epoch)

    return top1.avg
# }}}

# Logging {{{
class BenchLogger(object):
    def __init__(self, name, total_bs, warmup_iter):
        self.name = name
        self.data_time = AverageMeter()
        self.batch_time = AverageMeter()
        self.warmup_iter = warmup_iter
        self.total_bs = total_bs
        self.i = 0

    def reset(self):
        self.data_time.reset()
        self.batch_time.reset()
        self.i = 0

    def iter_callback(self, d):
        bt = d['batch_time']
        dt = d['data_time']
        if self.i >= self.warmup_iter:
            self.data_time.update(dt)
            self.batch_time.update(bt)
        self.i += 1

        print("Iter: [{}]\tbatch: {:.3f}\tdata: {:.3f}\timg/s (compute): {:.3f}\timg/s (total): {:.3f}".format(
            self.i, dt + bt, dt,
            self.total_bs / bt, self.total_bs / (bt + dt)))

    def end_callback(self):
        print("{} summary\tBatch Time: {:.3f}\tData Time: {:.3f}\timg/s (compute): {:.1f}\timg/s (total): {:.1f}".format(
            self.name,
            self.batch_time.avg, self.data_time.avg,
            self.total_bs / self.batch_time.avg,
            self.total_bs / (self.batch_time.avg + self.data_time.avg)))


class EpochLogger(object):
    def __init__(self, name, total_iterations, args):
        self.name = name
        self.args = args
        self.print_freq = args.print_freq
        self.total_iterations = total_iterations
        self.top1 = AverageMeter()
        self.top5 = AverageMeter()
        self.loss = AverageMeter()
        self.time = AverageMeter()
        self.data = AverageMeter()

    def iter_callback(self, epoch, iteration, d):
        self.top1.update(d['top1'], d['size'])
        self.top5.update(d['top5'], d['size'])
        self.loss.update(d['loss'], d['size'])
        self.time.update(d['time'], d['size'])
        self.data.update(d['data'], d['size'])

        if iteration % self.print_freq == 0:
            print('{0}:\t{1} [{2}/{3}]\t'
                  'Time {time.val:.3f} ({time.avg:.3f})\t'
                  'Data time {data.val:.3f} ({data.avg:.3f})\t'
                  'Speed {4:.3f} ({5:.3f})\t'
                  'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      self.name, epoch, iteration, self.total_iterations,
                      self.args.world_size * self.args.batch_size / self.time.val,
                      self.args.world_size * self.args.batch_size / self.time.avg,
                      time=self.time,
                      data=self.data,
                      loss=self.loss,
                      top1=self.top1,
                      top5=self.top5))

    def epoch_callback(self, epoch):
        print('{0} epoch {1} summary:\t'
              'Time {time.avg:.3f}\t'
              'Data time {data.avg:.3f}\t'
              'Speed {2:.3f}\t'
              'Loss {loss.avg:.4f}\t'
              'Prec@1 {top1.avg:.3f}\t'
              'Prec@5 {top5.avg:.3f}'.format(
                  self.name, epoch,
                  self.args.world_size * self.args.batch_size / self.time.avg,
                  time=self.time, data=self.data,
                  loss=self.loss, top1=self.top1, top5=self.top5))

        self.top1.reset()
        self.top5.reset()
        self.loss.reset()
        self.time.reset()
        self.data.reset()


class PrintLogger(object):
    def __init__(self, train_iterations, val_iterations, args):
        self.train_logger = EpochLogger("Train", train_iterations, args)
        self.val_logger = EpochLogger("Eval", val_iterations, args)

    def train_iter_callback(self, epoch, iteration, d):
        self.train_logger.iter_callback(epoch, iteration, d)

    def train_epoch_callback(self, epoch):
        self.train_logger.epoch_callback(epoch)

    def val_iter_callback(self, epoch, iteration, d):
        self.val_logger.iter_callback(epoch, iteration, d)

    def val_epoch_callback(self, epoch):
        self.val_logger.epoch_callback(epoch)

    def experiment_timer(self, exp_duration):
        print("Experiment took {} seconds".format(exp_duration))

    def end_callback(self):
        pass


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
# }}}


def should_backup_checkpoint(args):
    def _sbc(epoch):
        return args.gather_checkpoints and (epoch < 10 or epoch % 10 == 0)
    return _sbc


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', backup_filename=None):
    if (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == 0:
        print("SAVING")
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename, 'model_best.pth.tar')
        if backup_filename is not None:
            shutil.copyfile(filename, backup_filename)


def timed_generator(gen):
    start = time.time()
    for g in gen:
        end = time.time()
        t = end - start
        yield g, t
        start = time.time()


def timed_function(f):
    def _timed_function(*args, **kwargs):
        start = time.time()
        ret = f(*args, **kwargs)
        return ret, time.time() - start
    return _timed_function


def adjust_learning_rate(args):
    def _alr(optimizer, epoch):
        """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
        if epoch < args.warmup:
            lr = args.lr * (epoch + 1) / (args.warmup + 1)
        else:
            if epoch < 30:
                p = 0
            elif epoch < 60:
                p = 1
            elif epoch < 80:
                p = 2
            else:
                p = 3
            lr = args.lr * (0.1 ** p)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    return _alr
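
# With the default settings (lr=0.1, warmup=0) the resulting schedule is:
#   epochs  0-29: lr = 0.1
#   epochs 30-59: lr = 0.01
#   epochs 60-79: lr = 0.001
#   epochs 80+  : lr = 0.0001
# With --warmup N, the first N epochs ramp linearly: lr = args.lr * (epoch + 1) / (N + 1).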


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.reduce_op.SUM)
    rt /= torch.distributed.get_world_size()
    return rt
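
# Note: reduce_tensor sums the metric across all distributed workers and divides by the
# world size, so every rank ends up with the same cross-worker average of loss/accuracy.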


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

    add_parser_arguments(parser)
    args = parser.parse_args()
    cudnn.benchmark = True

    main()
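
# Example invocations (a sketch; the file name `main.py` and the dataset path are
# assumptions, adjust them to your setup):
#   Single-GPU fp32 training:
#       python main.py /path/to/imagenet --arch resnet50 --epochs 90 -b 256
#   Mixed-precision training with dynamic loss scaling:
#       python main.py /path/to/imagenet --arch resnet50 --fp16 --dynamic-loss-scale
#   Training-throughput benchmark only:
#       python main.py /path/to/imagenet --arch resnet50 --benchmark-training
# For multi-GPU runs, --world-size and --rank are set per process; the help text above
# refers to a `python -m multiproc` launcher provided by the surrounding repository.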