2019-01-23 17:14:51 +01:00
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
2021-06-30 13:52:43 +02:00
from packaging . version import Version
2019-01-23 17:14:51 +01:00
from nvidia import dali
from nvidia . dali . pipeline import Pipeline
import nvidia . dali . ops as ops
import nvidia . dali . types as types
from nvidia . dali . plugin . mxnet import DALIClassificationIterator
2019-10-21 19:20:40 +02:00
import horovod . mxnet as hvd
2019-01-23 17:14:51 +01:00
def add_dali_args(parser):
    """Register the DALI data-backend command-line options on ``parser``.

    All options are placed in a dedicated "DALI data backend" argument group.

    Args:
        parser: an ``argparse.ArgumentParser`` (or anything exposing
            ``add_argument_group``) to extend.

    Returns:
        The same ``parser`` object, so calls can be chained.
    """
    group = parser.add_argument_group('DALI data backend', 'entire group applies only to dali data backend')
    group.add_argument('--dali-separ-val', action='store_true',
                       help='each process will perform independent validation on whole val-set')
    group.add_argument('--dali-threads', type=int, default=4,
                       help="number of threads per GPU for DALI")
    group.add_argument('--dali-validation-threads', type=int, default=10,
                       help="number of threads per GPU for DALI for validation")
    group.add_argument('--dali-prefetch-queue', type=int, default=2,
                       help="DALI prefetch queue depth")
    # get_rec_iter converts this value from MB to bytes before use.
    group.add_argument('--dali-nvjpeg-memory-padding', type=int, default=64,
                       help="Memory padding value for nvJPEG (in MB)")
    # Interpreted as a boolean flag by the training pipeline (0 = separate decode + crop).
    group.add_argument('--dali-fuse-decoder', type=int, default=1,
                       help="0 or 1 whether to fuse decoder or not")
    group.add_argument('--dali-nvjpeg-width-hint', type=int, default=5980,
                       help="Width hint value for nvJPEG (in pixels)")
    group.add_argument('--dali-nvjpeg-height-hint', type=int, default=6430,
                       help="Height hint value for nvJPEG (in pixels)")
    group.add_argument('--dali-dont-use-mmap', default=False, action='store_true',
                       help="Use plain I/O instead of MMAP for datasets")
    return parser
class HybridTrainPipe(Pipeline):
    """DALI training pipeline over an MXNet RecordIO dataset.

    Stages: shuffled, sharded record reading -> image decoding (optionally
    fused with a random crop) -> resize / random-resized-crop to
    ``crop_shape`` -> crop, random mirror (p=0.5) and mean/std normalization
    on the GPU.  ``define_graph`` returns ``[images, labels]``.

    Notable parameters:
        args: parsed CLI namespace; reads ``dali_dont_use_mmap``,
            ``dali_fuse_decoder``, ``rgb_mean`` and ``rgb_std``.
        crop_shape: indexed as ``(height, width)`` for the fused-decoder resize.
        nvjpeg_padding: decoder device/host memory padding (get_rec_iter
            passes it in bytes).
        dali_cpu: keep decoding and resizing on the CPU.
        nvjpeg_width_hint / nvjpeg_height_hint: decoder preallocation hints,
            forwarded only when the installed DALI is at least 1.2.0.
    """

    def __init__(self, args, batch_size, num_threads, device_id, rec_path, idx_path,
                 shard_id, num_shards, crop_shape, nvjpeg_padding, prefetch_queue=3,
                 output_layout=types.NCHW, pad_output=True, dtype='float16', dali_cpu=False,
                 nvjpeg_width_hint=5980, nvjpeg_height_hint=6430,
                 ):
        super().__init__(batch_size, num_threads, device_id,
                         seed=12 + device_id,
                         prefetch_queue_depth=prefetch_queue)
        self.input = ops.MXNetReader(path=[rec_path], index_path=[idx_path],
                                     random_shuffle=True,
                                     shard_id=shard_id, num_shards=num_shards,
                                     dont_use_mmap=args.dali_dont_use_mmap)

        # (resize device, decoder device)
        dali_device, decoder_device = ("cpu", "cpu") if dali_cpu else ("gpu", "mixed")

        # The preallocation hints are only passed to DALI >= 1.2.0.
        supports_hints = Version(dali.__version__) >= Version("1.2.0")
        hint_kwargs = {
            "preallocate_width_hint": nvjpeg_width_hint,
            "preallocate_height_hint": nvjpeg_height_hint,
        } if supports_hints else {}

        fused = args.dali_fuse_decoder
        decoder_op = ops.ImageDecoderRandomCrop if fused else ops.ImageDecoder
        self.decode = decoder_op(device=decoder_device, output_type=types.RGB,
                                 device_memory_padding=nvjpeg_padding,
                                 host_memory_padding=nvjpeg_padding,
                                 **hint_kwargs)

        if not fused:
            # Unfused path: the random crop happens here, after a plain decode.
            self.resize = ops.RandomResizedCrop(device=dali_device, size=crop_shape)
        else:
            # Fused path: the decoder already cropped, so only scale to size.
            self.resize = ops.Resize(device=dali_device,
                                     resize_x=crop_shape[1], resize_y=crop_shape[0])

        out_dtype = types.FLOAT16 if dtype == 'float16' else types.FLOAT
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=out_dtype,
                                            output_layout=output_layout,
                                            crop=crop_shape,
                                            pad_output=pad_output,
                                            image_type=types.RGB,
                                            mean=args.rgb_mean,
                                            std=args.rgb_std)
        self.coin = ops.CoinFlip(probability=0.5)

    def define_graph(self):
        """Wire reader -> decoder -> resize -> CMN; returns ``[images, labels]``."""
        mirror_flag = self.coin()
        self.jpegs, self.labels = self.input(name="Reader")
        decoded = self.decode(self.jpegs)
        resized = self.resize(decoded)
        normalized = self.cmnp(resized.gpu(), mirror=mirror_flag)
        return [normalized, self.labels]
class HybridValPipe(Pipeline):
    """DALI validation pipeline over an MXNet RecordIO dataset.

    Stages: unshuffled, sharded record reading -> image decoding ->
    optional shorter-side resize -> crop and mean/std normalization on the
    GPU (no mirroring).  ``define_graph`` returns ``[images, labels]``.

    Notable parameters:
        args: parsed CLI namespace; reads ``dali_dont_use_mmap``,
            ``rgb_mean`` and ``rgb_std``.
        resize_shp: if truthy, the shorter image side is resized to this
            value before cropping; otherwise no resize stage is created.
        nvjpeg_padding: decoder device/host memory padding (get_rec_iter
            passes it in bytes).
        dali_cpu: keep decoding and resizing on the CPU.
        nvjpeg_width_hint / nvjpeg_height_hint: decoder preallocation hints,
            forwarded only when the installed DALI is at least 1.2.0.
    """

    def __init__(self, args, batch_size, num_threads, device_id, rec_path, idx_path,
                 shard_id, num_shards, crop_shape, nvjpeg_padding, prefetch_queue=3, resize_shp=None,
                 output_layout=types.NCHW, pad_output=True, dtype='float16', dali_cpu=False,
                 nvjpeg_width_hint=5980, nvjpeg_height_hint=6430):
        super().__init__(batch_size, num_threads, device_id,
                         seed=12 + device_id,
                         prefetch_queue_depth=prefetch_queue)
        self.input = ops.MXNetReader(path=[rec_path], index_path=[idx_path],
                                     random_shuffle=False,
                                     shard_id=shard_id, num_shards=num_shards,
                                     dont_use_mmap=args.dali_dont_use_mmap)

        # (resize device, decoder device)
        dali_device, decoder_device = ("cpu", "cpu") if dali_cpu else ("gpu", "mixed")

        # The preallocation hints are only passed to DALI >= 1.2.0.
        hint_kwargs = {}
        if Version(dali.__version__) >= Version("1.2.0"):
            hint_kwargs["preallocate_width_hint"] = nvjpeg_width_hint
            hint_kwargs["preallocate_height_hint"] = nvjpeg_height_hint

        self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB,
                                       device_memory_padding=nvjpeg_padding,
                                       host_memory_padding=nvjpeg_padding,
                                       **hint_kwargs)

        if resize_shp:
            self.resize = ops.Resize(device=dali_device, resize_shorter=resize_shp)
        else:
            self.resize = None

        out_dtype = types.FLOAT16 if dtype == 'float16' else types.FLOAT
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=out_dtype,
                                            output_layout=output_layout,
                                            crop=crop_shape,
                                            pad_output=pad_output,
                                            image_type=types.RGB,
                                            mean=args.rgb_mean,
                                            std=args.rgb_std)

    def define_graph(self):
        """Wire reader -> decoder -> [resize] -> CMN; returns ``[images, labels]``."""
        self.jpegs, self.labels = self.input(name="Reader")
        decoded = self.decode(self.jpegs)
        prepared = self.resize(decoded) if self.resize else decoded
        normalized = self.cmnp(prepared.gpu())
        return [normalized, self.labels]
2019-10-21 19:20:40 +02:00
def _worker_shard_size(total, rank, num_workers):
    """Return how many of ``total`` samples are read by worker ``rank``.

    Each worker owns a contiguous run of reader shards (``shard_id`` is
    ``local_gpu_index + len(gpus) * rank``), so its share is the slice
    ``[total * rank // num_workers, total * (rank + 1) // num_workers)``.

    NOTE(review): assumes DALI readers start shard ``i`` of ``S`` at sample
    ``total * i // S`` (floor split), which assigns the leftover samples
    differently than ``total // W + (rank < total % W)``. Confirm against
    the installed DALI version.
    """
    return total * (rank + 1) // num_workers - total * rank // num_workers


def get_rec_iter(args, kv=None, dali_cpu=False):
    """Build DALI train (and optionally validation) iterators for MXNet.

    One pipeline is created per local GPU in ``args.gpus``. Data is sharded
    across all GPUs of all workers.

    Args:
        args: parsed CLI namespace. Reads ``gpus``, ``batch_size``,
            ``image_shape``, ``input_layout``, ``kv_store``, ``dtype``,
            ``data_train``/``data_train_idx``, ``data_val``/``data_val_idx``,
            ``data_val_resize``, ``num_examples`` and the ``dali_*`` options
            registered by ``add_dali_args``.
        kv: optional MXNet kvstore. Its ``rank``/``num_workers`` are used
            when ``args.kv_store`` does not mention horovod. When it is None,
            a single worker of rank 0 is assumed.
        dali_cpu: decode and resize on the CPU instead of the GPU.

    Returns:
        ``(train_iter, val_iter)``. ``val_iter`` is None when
        ``args.data_val`` is falsy.
    """
    gpus = args.gpus
    # A 4-channel image_shape requests padded CMN output.
    pad_output = (args.image_shape[0] == 4)
    crop_shape = args.image_shape[1:]
    # the input_layout w.r.t. the model is the output_layout of the image pipeline
    output_layout = types.NHWC if args.input_layout == 'NHWC' else types.NCHW

    if 'horovod' in args.kv_store:
        rank = hvd.rank()
        num_workers = hvd.size()
    else:
        rank = kv.rank if kv else 0
        num_workers = kv.num_workers if kv else 1

    # args.batch_size is global; each pipeline handles one local GPU's slice.
    batch_size = args.batch_size // num_workers // len(gpus)
    num_shards = len(gpus) * num_workers
    # Constructor arguments shared by the train and validation pipelines.
    common = dict(
        args=args,
        batch_size=batch_size,
        crop_shape=crop_shape,
        output_layout=output_layout,
        dtype=args.dtype,
        pad_output=pad_output,
        dali_cpu=dali_cpu,
        nvjpeg_padding=args.dali_nvjpeg_memory_padding * 1024 * 1024,  # MB -> bytes
        prefetch_queue=args.dali_prefetch_queue,
        nvjpeg_width_hint=args.dali_nvjpeg_width_hint,
        nvjpeg_height_hint=args.dali_nvjpeg_height_hint,
    )

    # enumerate() gives each local GPU a distinct shard, even if ids repeat.
    trainpipes = [HybridTrainPipe(num_threads=args.dali_threads,
                                  device_id=gpu_id,
                                  rec_path=args.data_train,
                                  idx_path=args.data_train_idx,
                                  shard_id=local_idx + len(gpus) * rank,
                                  num_shards=num_shards,
                                  **common)
                  for local_idx, gpu_id in enumerate(gpus)]

    valpipes = None
    if args.data_val:
        # With --dali-separ-val every pipeline reads the whole validation set.
        separ = args.dali_separ_val
        valpipes = [HybridValPipe(num_threads=args.dali_validation_threads,
                                  device_id=gpu_id,
                                  rec_path=args.data_val,
                                  idx_path=args.data_val_idx,
                                  shard_id=0 if separ else local_idx + len(gpus) * rank,
                                  num_shards=1 if separ else num_shards,
                                  resize_shp=args.data_val_resize,
                                  **common)
                    for local_idx, gpu_id in enumerate(gpus)]

    trainpipes[0].build()
    worker_val_examples = None
    if valpipes is not None:
        valpipes[0].build()
        val_total = valpipes[0].epoch_size("Reader")
        if args.dali_separ_val:
            worker_val_examples = val_total
        else:
            worker_val_examples = _worker_shard_size(val_total, rank, num_workers)

    train_total = trainpipes[0].epoch_size("Reader")
    if args.num_examples < train_total:
        warnings.warn("{} training examples will be used, although full training set contains {} examples".format(args.num_examples, train_total))
    dali_train_iter = DALIClassificationIterator(trainpipes, args.num_examples // num_workers)

    dali_val_iter = None
    if valpipes is not None:
        dali_val_iter = DALIClassificationIterator(valpipes, worker_val_examples, fill_last_batch=False)
    return dali_train_iter, dali_val_iter