506 lines
16 KiB
Python
506 lines
16 KiB
Python
# Copyright (c) 2018-2019, NVIDIA CORPORATION
|
|
# Copyright (c) 2017- Facebook, Inc
|
|
#
|
|
# All rights reserved.
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
#
|
|
# * Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
#
|
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
#
|
|
# * Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
import os
|
|
import torch
|
|
import numpy as np
|
|
import torchvision.datasets as datasets
|
|
import torchvision.transforms as transforms
|
|
from PIL import Image
|
|
from functools import partial
|
|
|
|
DATA_BACKEND_CHOICES = ["pytorch", "syntetic"]
|
|
try:
|
|
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
|
|
from nvidia.dali.pipeline import Pipeline
|
|
import nvidia.dali.ops as ops
|
|
import nvidia.dali.types as types
|
|
|
|
DATA_BACKEND_CHOICES.append("dali-gpu")
|
|
DATA_BACKEND_CHOICES.append("dali-cpu")
|
|
except ImportError:
|
|
print(
|
|
"Please install DALI from https://www.github.com/NVIDIA/DALI to run this example."
|
|
)
|
|
|
|
|
|
def load_jpeg_from_file(path, cuda=True, fp16=False):
|
|
img_transforms = transforms.Compose(
|
|
[transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
|
|
)
|
|
|
|
img = img_transforms(Image.open(path))
|
|
with torch.no_grad():
|
|
# mean and std are not multiplied by 255 as they are in training script
|
|
# torch dataloader reads data into bytes whereas loading directly
|
|
# through PIL creates a tensor with floats in [0,1] range
|
|
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
|
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
|
|
|
if cuda:
|
|
mean = mean.cuda()
|
|
std = std.cuda()
|
|
img = img.cuda()
|
|
if fp16:
|
|
mean = mean.half()
|
|
std = std.half()
|
|
img = img.half()
|
|
else:
|
|
img = img.float()
|
|
|
|
input = img.unsqueeze(0).sub_(mean).div_(std)
|
|
|
|
return input
|
|
|
|
|
|
class HybridTrainPipe(Pipeline):
|
|
def __init__(
|
|
self, batch_size, num_threads, device_id, data_dir, crop, dali_cpu=False
|
|
):
|
|
super(HybridTrainPipe, self).__init__(
|
|
batch_size, num_threads, device_id, seed=12 + device_id
|
|
)
|
|
if torch.distributed.is_initialized():
|
|
rank = torch.distributed.get_rank()
|
|
world_size = torch.distributed.get_world_size()
|
|
else:
|
|
rank = 0
|
|
world_size = 1
|
|
|
|
self.input = ops.FileReader(
|
|
file_root=data_dir,
|
|
shard_id=rank,
|
|
num_shards=world_size,
|
|
random_shuffle=True,
|
|
)
|
|
|
|
if dali_cpu:
|
|
dali_device = "cpu"
|
|
self.decode = ops.ImageDecoder(device=dali_device, output_type=types.RGB)
|
|
else:
|
|
dali_device = "gpu"
|
|
# This padding sets the size of the internal nvJPEG buffers to be able to handle all images from full-sized ImageNet
|
|
# without additional reallocations
|
|
self.decode = ops.ImageDecoder(
|
|
device="mixed",
|
|
output_type=types.RGB,
|
|
device_memory_padding=211025920,
|
|
host_memory_padding=140544512,
|
|
)
|
|
|
|
self.res = ops.RandomResizedCrop(
|
|
device=dali_device,
|
|
size=[crop, crop],
|
|
interp_type=types.INTERP_LINEAR,
|
|
random_aspect_ratio=[0.75, 4.0 / 3.0],
|
|
random_area=[0.08, 1.0],
|
|
num_attempts=100,
|
|
)
|
|
|
|
self.cmnp = ops.CropMirrorNormalize(
|
|
device="gpu",
|
|
output_dtype=types.FLOAT,
|
|
output_layout=types.NCHW,
|
|
crop=(crop, crop),
|
|
image_type=types.RGB,
|
|
mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
|
|
std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
|
|
)
|
|
self.coin = ops.CoinFlip(probability=0.5)
|
|
|
|
def define_graph(self):
|
|
rng = self.coin()
|
|
self.jpegs, self.labels = self.input(name="Reader")
|
|
images = self.decode(self.jpegs)
|
|
images = self.res(images)
|
|
output = self.cmnp(images.gpu(), mirror=rng)
|
|
return [output, self.labels]
|
|
|
|
|
|
class HybridValPipe(Pipeline):
|
|
def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size):
|
|
super(HybridValPipe, self).__init__(
|
|
batch_size, num_threads, device_id, seed=12 + device_id
|
|
)
|
|
if torch.distributed.is_initialized():
|
|
rank = torch.distributed.get_rank()
|
|
world_size = torch.distributed.get_world_size()
|
|
else:
|
|
rank = 0
|
|
world_size = 1
|
|
|
|
self.input = ops.FileReader(
|
|
file_root=data_dir,
|
|
shard_id=rank,
|
|
num_shards=world_size,
|
|
random_shuffle=False,
|
|
)
|
|
|
|
self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
|
|
self.res = ops.Resize(device="gpu", resize_shorter=size)
|
|
self.cmnp = ops.CropMirrorNormalize(
|
|
device="gpu",
|
|
output_dtype=types.FLOAT,
|
|
output_layout=types.NCHW,
|
|
crop=(crop, crop),
|
|
image_type=types.RGB,
|
|
mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
|
|
std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
|
|
)
|
|
|
|
def define_graph(self):
|
|
self.jpegs, self.labels = self.input(name="Reader")
|
|
images = self.decode(self.jpegs)
|
|
images = self.res(images)
|
|
output = self.cmnp(images)
|
|
return [output, self.labels]
|
|
|
|
|
|
class DALIWrapper(object):
|
|
def gen_wrapper(dalipipeline, num_classes, one_hot, memory_format):
|
|
for data in dalipipeline:
|
|
input = data[0]["data"].contiguous(memory_format=memory_format)
|
|
target = torch.reshape(data[0]["label"], [-1]).cuda().long()
|
|
if one_hot:
|
|
target = expand(num_classes, torch.float, target)
|
|
yield input, target
|
|
dalipipeline.reset()
|
|
|
|
def __init__(self, dalipipeline, num_classes, one_hot, memory_format):
|
|
self.dalipipeline = dalipipeline
|
|
self.num_classes = num_classes
|
|
self.one_hot = one_hot
|
|
self.memory_format = memory_format
|
|
|
|
def __iter__(self):
|
|
return DALIWrapper.gen_wrapper(
|
|
self.dalipipeline, self.num_classes, self.one_hot, self.memory_format
|
|
)
|
|
|
|
|
|
def get_dali_train_loader(dali_cpu=False):
|
|
def gdtl(
|
|
data_path,
|
|
batch_size,
|
|
num_classes,
|
|
one_hot,
|
|
start_epoch=0,
|
|
workers=5,
|
|
_worker_init_fn=None,
|
|
fp16=False,
|
|
memory_format=torch.contiguous_format,
|
|
):
|
|
if torch.distributed.is_initialized():
|
|
rank = torch.distributed.get_rank()
|
|
world_size = torch.distributed.get_world_size()
|
|
else:
|
|
rank = 0
|
|
world_size = 1
|
|
|
|
traindir = os.path.join(data_path, "train")
|
|
|
|
pipe = HybridTrainPipe(
|
|
batch_size=batch_size,
|
|
num_threads=workers,
|
|
device_id=rank % torch.cuda.device_count(),
|
|
data_dir=traindir,
|
|
crop=224,
|
|
dali_cpu=dali_cpu,
|
|
)
|
|
|
|
pipe.build()
|
|
train_loader = DALIClassificationIterator(
|
|
pipe, size=int(pipe.epoch_size("Reader") / world_size)
|
|
)
|
|
|
|
return (
|
|
DALIWrapper(train_loader, num_classes, one_hot, memory_format),
|
|
int(pipe.epoch_size("Reader") / (world_size * batch_size)),
|
|
)
|
|
|
|
return gdtl
|
|
|
|
|
|
def get_dali_val_loader():
|
|
def gdvl(
|
|
data_path,
|
|
batch_size,
|
|
num_classes,
|
|
one_hot,
|
|
workers=5,
|
|
_worker_init_fn=None,
|
|
fp16=False,
|
|
memory_format=torch.contiguous_format,
|
|
):
|
|
if torch.distributed.is_initialized():
|
|
rank = torch.distributed.get_rank()
|
|
world_size = torch.distributed.get_world_size()
|
|
else:
|
|
rank = 0
|
|
world_size = 1
|
|
|
|
valdir = os.path.join(data_path, "val")
|
|
|
|
pipe = HybridValPipe(
|
|
batch_size=batch_size,
|
|
num_threads=workers,
|
|
device_id=rank % torch.cuda.device_count(),
|
|
data_dir=valdir,
|
|
crop=224,
|
|
size=256,
|
|
)
|
|
|
|
pipe.build()
|
|
val_loader = DALIClassificationIterator(
|
|
pipe, size=int(pipe.epoch_size("Reader") / world_size)
|
|
)
|
|
|
|
return (
|
|
DALIWrapper(val_loader, num_classes, one_hot, memory_format),
|
|
int(pipe.epoch_size("Reader") / (world_size * batch_size)),
|
|
)
|
|
|
|
return gdvl
|
|
|
|
|
|
def fast_collate(memory_format, batch):
|
|
imgs = [img[0] for img in batch]
|
|
targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
|
|
w = imgs[0].size[0]
|
|
h = imgs[0].size[1]
|
|
tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8).contiguous(
|
|
memory_format=memory_format
|
|
)
|
|
for i, img in enumerate(imgs):
|
|
nump_array = np.asarray(img, dtype=np.uint8)
|
|
if nump_array.ndim < 3:
|
|
nump_array = np.expand_dims(nump_array, axis=-1)
|
|
nump_array = np.rollaxis(nump_array, 2)
|
|
|
|
tensor[i] += torch.from_numpy(nump_array)
|
|
|
|
return tensor, targets
|
|
|
|
|
|
def expand(num_classes, dtype, tensor):
|
|
e = torch.zeros(
|
|
tensor.size(0), num_classes, dtype=dtype, device=torch.device("cuda")
|
|
)
|
|
e = e.scatter(1, tensor.unsqueeze(1), 1.0)
|
|
return e
|
|
|
|
|
|
class PrefetchedWrapper(object):
|
|
def prefetched_loader(loader, num_classes, fp16, one_hot):
|
|
mean = (
|
|
torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255])
|
|
.cuda()
|
|
.view(1, 3, 1, 1)
|
|
)
|
|
std = (
|
|
torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255])
|
|
.cuda()
|
|
.view(1, 3, 1, 1)
|
|
)
|
|
if fp16:
|
|
mean = mean.half()
|
|
std = std.half()
|
|
|
|
stream = torch.cuda.Stream()
|
|
first = True
|
|
|
|
for next_input, next_target in loader:
|
|
with torch.cuda.stream(stream):
|
|
next_input = next_input.cuda(non_blocking=True)
|
|
next_target = next_target.cuda(non_blocking=True)
|
|
if fp16:
|
|
next_input = next_input.half()
|
|
if one_hot:
|
|
next_target = expand(num_classes, torch.half, next_target)
|
|
else:
|
|
next_input = next_input.float()
|
|
if one_hot:
|
|
next_target = expand(num_classes, torch.float, next_target)
|
|
|
|
next_input = next_input.sub_(mean).div_(std)
|
|
|
|
if not first:
|
|
yield input, target
|
|
else:
|
|
first = False
|
|
|
|
torch.cuda.current_stream().wait_stream(stream)
|
|
input = next_input
|
|
target = next_target
|
|
|
|
yield input, target
|
|
|
|
def __init__(self, dataloader, start_epoch, num_classes, fp16, one_hot):
|
|
self.dataloader = dataloader
|
|
self.fp16 = fp16
|
|
self.epoch = start_epoch
|
|
self.one_hot = one_hot
|
|
self.num_classes = num_classes
|
|
|
|
def __iter__(self):
|
|
if self.dataloader.sampler is not None and isinstance(
|
|
self.dataloader.sampler, torch.utils.data.distributed.DistributedSampler
|
|
):
|
|
|
|
self.dataloader.sampler.set_epoch(self.epoch)
|
|
self.epoch += 1
|
|
return PrefetchedWrapper.prefetched_loader(
|
|
self.dataloader, self.num_classes, self.fp16, self.one_hot
|
|
)
|
|
|
|
def __len__(self):
|
|
return len(self.dataloader)
|
|
|
|
|
|
def get_pytorch_train_loader(
|
|
data_path,
|
|
batch_size,
|
|
num_classes,
|
|
one_hot,
|
|
start_epoch=0,
|
|
workers=5,
|
|
_worker_init_fn=None,
|
|
fp16=False,
|
|
memory_format=torch.contiguous_format,
|
|
):
|
|
traindir = os.path.join(data_path, "train")
|
|
train_dataset = datasets.ImageFolder(
|
|
traindir,
|
|
transforms.Compose(
|
|
[transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()]
|
|
),
|
|
)
|
|
|
|
if torch.distributed.is_initialized():
|
|
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
|
|
else:
|
|
train_sampler = None
|
|
|
|
train_loader = torch.utils.data.DataLoader(
|
|
train_dataset,
|
|
batch_size=batch_size,
|
|
shuffle=(train_sampler is None),
|
|
num_workers=workers,
|
|
worker_init_fn=_worker_init_fn,
|
|
pin_memory=True,
|
|
sampler=train_sampler,
|
|
collate_fn=partial(fast_collate, memory_format),
|
|
drop_last=True,
|
|
)
|
|
|
|
return (
|
|
PrefetchedWrapper(train_loader, start_epoch, num_classes, fp16, one_hot),
|
|
len(train_loader),
|
|
)
|
|
|
|
|
|
def get_pytorch_val_loader(
|
|
data_path,
|
|
batch_size,
|
|
num_classes,
|
|
one_hot,
|
|
workers=5,
|
|
_worker_init_fn=None,
|
|
fp16=False,
|
|
memory_format=torch.contiguous_format,
|
|
):
|
|
valdir = os.path.join(data_path, "val")
|
|
val_dataset = datasets.ImageFolder(
|
|
valdir, transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])
|
|
)
|
|
|
|
if torch.distributed.is_initialized():
|
|
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
|
|
else:
|
|
val_sampler = None
|
|
|
|
val_loader = torch.utils.data.DataLoader(
|
|
val_dataset,
|
|
sampler=val_sampler,
|
|
batch_size=batch_size,
|
|
shuffle=False,
|
|
num_workers=workers,
|
|
worker_init_fn=_worker_init_fn,
|
|
pin_memory=True,
|
|
collate_fn=partial(fast_collate, memory_format),
|
|
)
|
|
|
|
return PrefetchedWrapper(val_loader, 0, num_classes, fp16, one_hot), len(val_loader)
|
|
|
|
|
|
class SynteticDataLoader(object):
|
|
def __init__(
|
|
self,
|
|
fp16,
|
|
batch_size,
|
|
num_classes,
|
|
num_channels,
|
|
height,
|
|
width,
|
|
one_hot,
|
|
memory_format=torch.contiguous_format,
|
|
):
|
|
input_data = (
|
|
torch.empty(batch_size, num_channels, height, width).contiguous(memory_format=memory_format).cuda().normal_(0, 1.0)
|
|
)
|
|
if one_hot:
|
|
input_target = torch.empty(batch_size, num_classes).cuda()
|
|
input_target[:, 0] = 1.0
|
|
else:
|
|
input_target = torch.randint(0, num_classes, (batch_size,))
|
|
input_target = input_target.cuda()
|
|
if fp16:
|
|
input_data = input_data.half()
|
|
|
|
self.input_data = input_data
|
|
self.input_target = input_target
|
|
|
|
def __iter__(self):
|
|
while True:
|
|
yield self.input_data, self.input_target
|
|
|
|
|
|
def get_syntetic_loader(
|
|
data_path,
|
|
batch_size,
|
|
num_classes,
|
|
one_hot,
|
|
start_epoch=0,
|
|
workers=None,
|
|
_worker_init_fn=None,
|
|
fp16=False,
|
|
memory_format=torch.contiguous_format,
|
|
):
|
|
return SynteticDataLoader(fp16, batch_size, num_classes, 3, 224, 224, one_hot, memory_format=memory_format), -1
|