# DeepLearningExamples/MxNet/Classification/RN50v1.5/fit.py
# Copyright 2017-2018 The Apache Software Foundation
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# -----------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" train fit utility """
import logging
import math
import os
import random
import sys
import time
from itertools import starmap

import dllogger
import horovod.mxnet as hvd
import mxnet as mx
import mxnet.contrib.amp as amp
import numpy as np
from mxnet import autograd as ag
from mxnet import gluon

import data
from benchmarking import BenchmarkingDataIter
from global_metrics import CompositeMeter, MaxMeter, MinMeter, AvgMeter, PercentileMeter

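# Every option below surfaces as an attribute on the parsed argparse namespace
# that fit() and model_fit() consume.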
def add_fit_args(parser):
    def int_list(x):
        return list(map(int, x.split(',')))

    def float_list(x):
        return list(map(float, x.split(',')))

    train = parser.add_argument_group('Training')
    train.add_argument('--mode', default='train_val', choices=('train_val', 'train', 'val', 'pred'),
                       help='mode')
    train.add_argument('--seed', type=int, default=None,
                       help='random seed')
    train.add_argument('--gpus', type=int_list, default=[0],
                       help='list of gpus to run on, e.g. 0 or 0,2,5')
    train.add_argument('--kv-store', type=str, default='device', choices=('device', 'horovod'),
                       help='key-value store type')
    train.add_argument('--dtype', type=str, default='float16', choices=('float32', 'float16'),
                       help='precision')
    train.add_argument('--amp', action='store_true',
                       help='if enabled, turn on AMP (Automatic Mixed Precision)')
    train.add_argument('--batch-size', type=int, default=192,
                       help='the batch size')
    train.add_argument('--num-epochs', type=int, default=90,
                       help='number of epochs')
    train.add_argument('--run-epochs', type=int, default=-1,
                       help='number of epochs to run in a single run')
    train.add_argument('--lr', type=float, default=0.1,
                       help='initial learning rate')
    train.add_argument('--lr-schedule', choices=('multistep', 'cosine'), default='cosine',
                       help='learning rate schedule')
    train.add_argument('--lr-factor', type=float, default=0.256,
                       help='the ratio by which to reduce the lr on each step')
    train.add_argument('--lr-steps', type=float_list, default=[],
                       help='the epochs at which to reduce the lr, e.g. 30,60')
    train.add_argument('--warmup-epochs', type=int, default=5,
                       help='the number of epochs over which to ramp up the lr to its scaled large-batch value')
    train.add_argument('--optimizer', type=str, default='sgd',
                       help='the optimizer type')
    train.add_argument('--mom', type=float, default=0.875,
                       help='momentum for sgd')
    train.add_argument('--wd', type=float, default=1 / 32768,
                       help='weight decay for sgd')
    train.add_argument('--label-smoothing', type=float, default=0.1,
                       help='label smoothing factor')
    train.add_argument('--mixup', type=float, default=0,
                       help='alpha parameter for mixup (if 0, mixup is not applied)')
    train.add_argument('--disp-batches', type=int, default=20,
                       help='show progress every n batches')
    train.add_argument('--model-prefix', type=str, default='model',
                       help='model checkpoint prefix')
    train.add_argument('--save-frequency', type=int, default=-1,
                       help='frequency of saving the model, in epochs (--model-prefix must be specified). '
                            'If -1, save only the best model. If 0, do not save anything.')
    train.add_argument('--begin-epoch', type=int, default=0,
                       help='start the model from an epoch')
    train.add_argument('--load', help='checkpoint to load')
    train.add_argument('--test-io', action='store_true',
                       help='test reading speed without training')
    train.add_argument('--test-io-mode', default='train', choices=('train', 'val'),
                       help='data to test')
    train.add_argument('--log', type=str, default='log.log',
                       help='file in which to save the log from the experiment')
    train.add_argument('--dllogger-log', type=str, default='dllogger_log.log',
                       help='file in which to save the dllogger log from the experiment')
    train.add_argument('--workspace', type=str, default='./',
                       help='path to the directory where results will be stored')
    train.add_argument('--no-metrics', action='store_true',
                       help='do not calculate evaluation metrics (for benchmarking)')
    train.add_argument('--benchmark-iters', type=int, default=None,
                       help='run only benchmark-iters iterations from each epoch')
    return train


def get_epoch_size(args, kv):
    return math.ceil(args.num_examples / args.batch_size)

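# Both schedules map a fractional epoch x to a learning rate: 'multistep'
# decays the base lr by lr_factor at every entry of --lr-steps passed so far,
# while 'cosine' anneals it between consecutive boundaries. Both ramp the lr
# up linearly over the first --warmup-epochs epochs.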
def get_lr_scheduler(args):
    def multistep_schedule(x):
        lr = args.lr * (args.lr_factor ** (len(list(filter(lambda step: step <= x, args.lr_steps)))))
        warmup_coeff = min(1, x / args.warmup_epochs)
        return warmup_coeff * lr

    def cosine_schedule(x):
        steps = args.lr_steps
        if not steps or steps[0] > args.warmup_epochs:
            steps = [args.warmup_epochs] + steps
        elif not steps or steps[0] != 0:
            steps = [0] + steps
        if steps[-1] != args.num_epochs:
            steps.append(args.num_epochs)
        if x < args.warmup_epochs:
            return args.lr * x / args.warmup_epochs
        for i, (step, next_step) in enumerate(zip(steps, steps[1:])):
            if next_step > x:
                return args.lr * 0.5 * (1 + math.cos(math.pi * (x - step) / (next_step - step))) * (args.lr_factor ** i)
        return 0

    schedules = {
        'multistep': multistep_schedule,
        'cosine': cosine_schedule,
    }
    return schedules[args.lr_schedule]


def load_model(args, model):
    if args.load is None:
        return False
    model.load_parameters(args.load)
    logging.info('Loaded model {}'.format(args.load))
    return True

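# Writes a periodic checkpoint every `save_frequency` epochs and keeps a
# separate `<prefix>_best.params` snapshot; under Horovod only rank 0 saves.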
def save_checkpoint(net, epoch, top1, best_acc, model_prefix, save_frequency, kvstore):
    if model_prefix is None or save_frequency == 0 or ('horovod' in kvstore and hvd.rank() != 0):
        return
    if save_frequency > 0 and (epoch + 1) % save_frequency == 0:
        fname = '{}_{:04}.params'.format(model_prefix, epoch)
        net.save_parameters(fname)
        logging.info('[Epoch {}] Saving checkpoint to {} with Accuracy: {:.4f}'.format(epoch, fname, top1))
    if top1 > best_acc:
        fname = '{}_best.params'.format(model_prefix)
        net.save_parameters(fname)
        logging.info('[Epoch {}] Saving checkpoint to {} with Accuracy: {:.4f}'.format(epoch, fname, top1))


def model_pred(args, model, image):
    from imagenet_classes import classes
    output = model(image.reshape(-1, *image.shape))[0].softmax().as_in_context(mx.cpu())
    top = output.argsort(is_ascend=False)[:10]
    for i, ind in enumerate(top):
        ind = int(ind.asscalar())
        logging.info('{:2d}. {:5.2f}% -> {}'.format(i + 1, output[ind].asscalar() * 100, classes[ind]))

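# Under Horovod, averages the metric values across all workers so every rank
# reports the same global score; a no-op for single-worker or local kvstore runs.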
def reduce_metrics(args, metrics, kvstore):
    if 'horovod' not in kvstore or not metrics[0] or hvd.size() == 1:
        return metrics
    m = mx.ndarray.array(metrics[1], ctx=mx.gpu(args.gpus[0]))
    reduced = hvd.allreduce(m)
    values = reduced.as_in_context(mx.cpu()).asnumpy().tolist()
    return (metrics[0], values)

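# Runs one pass over the validation set, trimming padded samples from the last
# batch, and returns (metric names/values, throughput and latency stats, raw
# per-batch durations).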

def model_score(args, net, val_data, metric, kvstore):
    if val_data is None:
        logging.info('Omitting validation: no data')
        return [], {}, []
    if not isinstance(metric, mx.metric.EvalMetric):
        metric = mx.metric.create(metric)
    metric.reset()
    val_data.reset()

    total_batch_size = val_data.batch_size * val_data._num_gpus * (hvd.size() if 'horovod' in kvstore else 1)

    durations = []
    tic = time.time()
    outputs = []
    for batches in val_data:
        # synchronize to previous iteration
        for o in outputs:
            o.wait_to_read()

        data = [b.data[0] for b in batches]
        label = [b.label[0][:len(b.data[0]) - b.pad] for b in batches if len(b.data[0]) != b.pad]
        outputs = [net(X) for X in data]
        outputs = [o[:len(b.data[0]) - b.pad] for o, b in zip(outputs, batches) if len(b.data[0]) != b.pad]
        metric.update(label, outputs)

        durations.append(time.time() - tic)
        tic = time.time()

    metric = reduce_metrics(args, metric.get_global(), kvstore)
    # drop warmup iterations (the first 10% of batches, capped at 100)
    durations = durations[min(len(durations) // 10, 100):]

    duration_stats = {
        'ips': total_batch_size / np.mean(durations),
        'latency_avg': np.mean(durations),
    }
    return metric, duration_stats, durations


class ScalarMetric(mx.metric.Loss):
    def update(self, _, scalar):
        self.sum_metric += scalar
        self.global_sum_metric += scalar
        self.num_inst += 1
        self.global_num_inst += 1

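# Smooths one-hot labels: the true class gets 1 - eta + eta / classes and
# every other class gets eta / classes.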
def label_smoothing(labels, classes, eta):
    return labels.one_hot(classes, on_value=1 - eta + eta / classes, off_value=eta / classes)

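# Main training loop: forward/backward on every GPU each iteration, optional
# AMP loss scaling, per-iteration dllogger output, and (in train_val mode)
# validation plus checkpointing at the end of each epoch.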
def model_fit(args, net, train_data, eval_metric, optimizer,
              optimizer_params, lr_scheduler, eval_data, global_metrics, kvstore, kv,
              begin_epoch, num_epoch, run_epoch, model_prefix):
    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)
    loss_metric = ScalarMetric()

    if 'horovod' in kvstore:
        trainer = hvd.DistributedTrainer(net.collect_params(), optimizer, optimizer_params)
    else:
        trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params,
                                kvstore=kv, update_on_kvstore=False)
    if args.amp:
        amp.init_trainer(trainer)

    sparse_label_loss = (args.label_smoothing == 0 and args.mixup == 0)
    loss = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)
    loss.hybridize(static_shape=True, static_alloc=True)

    local_batch_size = train_data.batch_size
    total_batch_size = local_batch_size * train_data._num_gpus * (hvd.size() if 'horovod' in kvstore else 1)
    durations = []

    epoch_size = get_epoch_size(args, kv)
    run_epoch = num_epoch if (run_epoch == -1) else (begin_epoch + run_epoch)

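    # Applies mixup (a convex combination of the batch with its reverse, with
    # Beta-distributed coefficients) and/or label smoothing to one GPU's batch.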
    def transform_data(images, labels):
        if args.mixup != 0:
            coeffs = mx.nd.array(np.random.beta(args.mixup, args.mixup, size=images.shape[0])).as_in_context(
                images.context)
            image_coeffs = coeffs.astype(images.dtype, copy=False).reshape(*coeffs.shape, 1, 1, 1)
            ret_images = image_coeffs * images + (1 - image_coeffs) * images[::-1]

            ret_labels = label_smoothing(labels, args.num_classes, args.label_smoothing)
            label_coeffs = coeffs.reshape(*coeffs.shape, 1)
            ret_labels = label_coeffs * ret_labels + (1 - label_coeffs) * ret_labels[::-1]
        else:
            ret_images = images
            if not sparse_label_loss:
                ret_labels = label_smoothing(labels, args.num_classes, args.label_smoothing)
            else:
                ret_labels = labels
        return ret_images, ret_labels

    i = -1
    best_accuracy = -1
    for epoch in range(begin_epoch, min(run_epoch, num_epoch)):
        tic = time.time()
        btic = time.time()
        etic = time.time()

        train_data.reset()
        eval_metric.reset()
        loss_metric.reset()
        logging.info('Starting epoch {}'.format(epoch))

        outputs = []
        for i, batches in enumerate(train_data):
            # synchronize to previous iteration
            # for o in outputs:
            #     o.wait_to_read()

            trainer.set_learning_rate(lr_scheduler(epoch + i / epoch_size))

            data = [b.data[0] for b in batches]
            label = [b.label[0].as_in_context(b.data[0].context) for b in batches]
            orig_label = label

            data, label = zip(*starmap(transform_data, zip(data, label)))

            outputs = []
            Ls = []
            with ag.record():
                for x, y in zip(data, label):
                    z = net(x)
                    L = loss(z, y)
                    # store the loss and do backward after we have done forward
                    # on all GPUs for better speed on multiple GPUs.
                    Ls.append(L)
                    outputs.append(z)

                if args.amp:
                    with amp.scale_loss(Ls, trainer) as scaled_loss:
                        ag.backward(scaled_loss)
                else:
                    ag.backward(Ls)
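
            # Horovod averages gradients across workers, so step() normalizes by
            # the local batch size; the plain Gluon trainer sums gradients over
            # all local GPUs, so it normalizes by the total batch size.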
            if 'horovod' in kvstore:
                trainer.step(local_batch_size)
            else:
                trainer.step(total_batch_size)

            loss_metric.update(..., np.mean([l.asnumpy() for l in Ls]).item())

            if args.disp_batches and not (i + 1) % args.disp_batches:
                dllogger_it_data = {
                    'train.loss': loss_metric.get()[1],
                    'train.ips': args.disp_batches * total_batch_size / (time.time() - btic),
                    'train.lr': trainer.learning_rate
                }
                dllogger.log((epoch, i), data=dllogger_it_data)

                loss_metric.reset_local()
                btic = time.time()

            durations.append(time.time() - tic)
            tic = time.time()

        durations = durations[min(len(durations) // 10, 100):]
        dllogger_epoch_data = {
            'train.loss': loss_metric.get_global()[1],
            'train.ips': total_batch_size / np.mean(durations)
        }

        if args.mode == 'train_val':
            logging.info('Validating epoch {}'.format(epoch))
            score, duration_stats, _ = model_score(args, net, eval_data, eval_metric, kvstore)
            dllogger_epoch_data.update(
                starmap(lambda key, val: ('val.{}'.format(key), val), zip(*score))
            )
            dllogger_epoch_data.update(
                starmap(lambda key, val: ('val.{}'.format(key), val), duration_stats.items())
            )

            score = dict(zip(*score))
            accuracy = score.get('accuracy', -1)
            save_checkpoint(net, epoch, accuracy, best_accuracy, model_prefix, args.save_frequency, kvstore)
            best_accuracy = max(best_accuracy, accuracy)

        global_metrics.update_dict(dllogger_epoch_data)
        dllogger.log(step=(epoch,), data=dllogger_epoch_data)


def fit(args, model, data_loader):
    """
    train a model
    args : argparse returns
    model : the neural network model
    data_loader : function that returns the train and val data iterators
    """
    start_time = time.time()

    # select gpu for horovod process
    if 'horovod' in args.kv_store:
        args.gpus = [args.gpus[hvd.local_rank()]]

    if args.amp:
        amp.init()

    if args.seed is not None:
        logging.info('Setting seeds to {}'.format(args.seed))
        random.seed(args.seed)
        np.random.seed(args.seed)
        mx.random.seed(args.seed)

    # kvstore
    if 'horovod' in args.kv_store:
        kv = None
        rank = hvd.rank()
        num_workers = hvd.size()
    else:
        kv = mx.kvstore.create(args.kv_store)
        rank = kv.rank
        num_workers = kv.num_workers

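    # I/O benchmark mode: drain the chosen iterator without running the model
    # and report raw data-pipeline throughput.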
    if args.test_io:
        train, val = data_loader(args, kv)
        if args.test_io_mode == 'train':
            data_iter = train
        else:
            data_iter = val

        tic = time.time()
        for i, batch in enumerate(data_iter):
            if isinstance(batch, list):
                for b in batch:
                    for j in b.data:
                        j.wait_to_read()
            else:
                for j in batch.data:
                    j.wait_to_read()
            if (i + 1) % args.disp_batches == 0:
                logging.info('Batch [{}]\tSpeed: {:.2f} samples/sec'.format(
                    i, args.disp_batches * args.batch_size / (time.time() - tic)))
                tic = time.time()
        return

    if not load_model(args, model):
        # all initializers should be specified in the model definition.
        # if not, this will raise an error
        model.initialize(mx.init.Initializer())

    # devices for training
    devs = list(map(mx.gpu, args.gpus))
    model.collect_params().reset_ctx(devs)

    if args.mode == 'pred':
        logging.info('Inferring image {}'.format(args.data_pred))
        model_pred(args, model, data.load_image(args, args.data_pred, devs[0]))
        return

    # learning rate
    lr_scheduler = get_lr_scheduler(args)
    optimizer_params = {
        'learning_rate': 0,
        'wd': args.wd,
        'multi_precision': True,
    }

    # Only a limited number of optimizers have a 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag', 'signum', 'lbsgd'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom

    # evaluation metrics
    if not args.no_metrics:
        eval_metrics = ['accuracy']
        eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=5))
    else:
        eval_metrics = []

    train, val = data_loader(args, kv)
    train = BenchmarkingDataIter(train, args.benchmark_iters)
    if val is not None:
        val = BenchmarkingDataIter(val, args.benchmark_iters)

    if 'horovod' in args.kv_store:
        # Fetch and broadcast parameters
        params = model.collect_params()
        if params is not None:
            hvd.broadcast_parameters(params, root_rank=0)

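    # Run-level aggregates reported once at the end: best observed values for
    # accuracy/loss, averages for throughput, percentiles for latency.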
    global_metrics = CompositeMeter()
    if args.mode in ['train_val', 'train']:
        global_metrics.register_metric('train.loss', MinMeter())
        global_metrics.register_metric('train.ips', AvgMeter())
    if args.mode in ['train_val', 'val']:
        global_metrics.register_metric('val.accuracy', MaxMeter())
        global_metrics.register_metric('val.top_k_accuracy_5', MaxMeter())
        global_metrics.register_metric('val.ips', AvgMeter())
        global_metrics.register_metric('val.latency_avg', AvgMeter())
    if args.mode in ['val']:
        global_metrics.register_metric('val.latency_50', PercentileMeter(50))
        global_metrics.register_metric('val.latency_90', PercentileMeter(90))
        global_metrics.register_metric('val.latency_95', PercentileMeter(95))
        global_metrics.register_metric('val.latency_99', PercentileMeter(99))
        global_metrics.register_metric('val.latency_100', PercentileMeter(100))

    # run
    if args.mode in ['train_val', 'train']:
        model_fit(
            args,
            model,
            train,
            begin_epoch=args.begin_epoch,
            num_epoch=args.num_epochs,
            run_epoch=args.run_epochs,
            eval_data=val,
            eval_metric=eval_metrics,
            global_metrics=global_metrics,
            kvstore=args.kv_store,
            kv=kv,
            optimizer=args.optimizer,
            optimizer_params=optimizer_params,
            lr_scheduler=lr_scheduler,
            model_prefix=os.path.join(args.workspace, args.model_prefix),
        )
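    # In 'val' mode the validation pass is repeated --num-epochs times for
    # benchmarking, so the latency percentiles come from a stable sample.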
    elif args.mode == 'val':
        for epoch in range(args.num_epochs):  # loop for benchmarking
            score, duration_stats, durations = model_score(args, model, val, eval_metrics, args.kv_store)
            dllogger_data = dict(starmap(lambda key, val: ('val.{}'.format(key), val), zip(*score)))
            dllogger_data.update(
                starmap(lambda key, val: ('val.{}'.format(key), val), duration_stats.items())
            )
            global_metrics.update_dict(dllogger_data)
            for percentile in [50, 90, 95, 99, 100]:
                metric_name = 'val.latency_{}'.format(percentile)
                dllogger_data[metric_name] = np.percentile(durations, percentile)
                global_metrics.update_metric(metric_name, durations)
            dllogger.log(step=(epoch,), data=dllogger_data)
    else:
        raise ValueError('Wrong mode')

    mx.nd.waitall()
    dllogger.log(tuple(), data=global_metrics.get())