#!/usr/bin/env python3
# Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
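"""Process the ImageNet (ILSVRC2012) validation set into per-class image directories.

Extracts the images from ILSVRC2012_img_val.tar, maps the devkit ground-truth
labels to class indices compatible with the training set, resizes each image so
that its shorter side is _RESIZE_MIN pixels, central-crops it to the target size,
and saves the result under <dataset-dir>/imagenet/<class-index>/.
"""
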
import os
import tarfile
from pathlib import Path
from typing import Dict

from PIL import Image
from tqdm import tqdm

DATASETS_DIR = os.environ.get("DATASETS_DIR", None)
IMAGENET_DIRNAME = "imagenet"
IMAGE_ARCHIVE_FILENAME = "ILSVRC2012_img_val.tar"
DEVKIT_ARCHIVE_FILENAME = "ILSVRC2012_devkit_t12.tar.gz"
LABELS_REL_PATH = "ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt"
META_REL_PATH = "ILSVRC2012_devkit_t12/data/meta.mat"
TARGET_SIZE = (224, 224)  # (width, height)
_RESIZE_MIN = 256  # length of the shorter side after the aspect-preserving resize


def parse_meta_mat(metafile) -> Dict[int, str]:
    """Build a mapping from devkit class index (ILSVRC2012_ID) to WordNet id (WNID)."""
    import scipy.io

    meta = scipy.io.loadmat(metafile, squeeze_me=True)["synsets"]
    # each synset record is (ILSVRC2012_ID, WNID, words, gloss, num_children, ...);
    # keep only leaf synsets (num_children == 0) - these are the 1000 actual classes
    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
    idcs, wnids = list(zip(*meta))[:2]
    idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
    return idx_to_wnid
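
# parse_meta_mat returns e.g. {1: "n02119789", 2: "n02100735", ...} - a mapping from the
# devkit's 1-based ILSVRC2012_ID of each leaf synset to its WNID (sample values are
# illustrative, not guaranteed)
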
def _process_image(image_file, target_size):
    """Resize so the shorter side equals _RESIZE_MIN, then central-crop to target_size."""
    image = Image.open(image_file)
    original_size = image.size

    # scale preserving the aspect ratio so that the shorter side equals _RESIZE_MIN
    scale_factor = max(_RESIZE_MIN / original_size[0], _RESIZE_MIN / original_size[1])
    resize_to = int(original_size[0] * scale_factor), int(original_size[1] * scale_factor)
    resized_image = image.resize(resize_to)

    # central crop to target_size
    left, upper = (resize_to[0] - target_size[0]) // 2, (resize_to[1] - target_size[1]) // 2
    cropped_image = resized_image.crop((left, upper, left + target_size[0], upper + target_size[1]))
    return cropped_image
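
# Worked example (arithmetic only, input size is hypothetical): for a 500x375 image and
# the default 224x224 target, scale_factor = max(256/500, 256/375) ~= 0.6827, so the
# image is resized to (341, 256); the crop box is then (58, 16, 282, 240), i.e. a
# centered 224x224 window.
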
def main():
    import argparse

    parser = argparse.ArgumentParser(
        description="Resize and central-crop the ImageNet validation images and sort them into per-class directories."
    )
    parser.add_argument(
        "--dataset-dir",
        help="Path to the dataset directory where the ImageNet archives are stored and processed files will be saved.",
        required=False,
        default=DATASETS_DIR,
    )
    parser.add_argument(
        "--target-size",
        help="Size of the target image. Format: <width>,<height>.",
        required=False,
        default=",".join(map(str, TARGET_SIZE)),
    )
    args = parser.parse_args()
    if args.dataset_dir is None:
        raise ValueError(
            "Please set the $DATASETS_DIR environment variable to point at the directory with the original "
            "dataset archives, where the processed files will also be stored. "
            "Alternatively, provide the --dataset-dir CLI argument."
        )
    datasets_dir = Path(args.dataset_dir)
    target_size = tuple(map(int, args.target_size.split(",")))

    image_archive_path = datasets_dir / IMAGE_ARCHIVE_FILENAME
    if not image_archive_path.exists():
        raise RuntimeError(
            f"There should be an {IMAGE_ARCHIVE_FILENAME} file in {datasets_dir}. "
            f"You need to download the dataset from http://www.image-net.org/download."
        )
    devkit_archive_path = datasets_dir / DEVKIT_ARCHIVE_FILENAME
    if not devkit_archive_path.exists():
        raise RuntimeError(
            f"There should be a {DEVKIT_ARCHIVE_FILENAME} file in {datasets_dir}. "
            f"You need to download the dataset from http://www.image-net.org/download."
        )
    with tarfile.open(devkit_archive_path, mode="r") as devkit_archive_file:
        labels_file = devkit_archive_file.extractfile(LABELS_REL_PATH)
        labels = list(map(int, labels_file.readlines()))

        # map validation labels (1-based indices from LABELS_REL_PATH) to WNIDs compatible with the training set
        meta_file = devkit_archive_file.extractfile(META_REL_PATH)
        idx_to_wnid = parse_meta_mat(meta_file)
        labels_wnid = [idx_to_wnid[idx] for idx in labels]

        # remap each WNID to its index in the sorted list of all WNIDs - this is how the network outputs classes
        available_wnids = sorted(set(labels_wnid))
        wnid_to_newidx = {wnid: new_cls for new_cls, wnid in enumerate(available_wnids)}
        labels = [wnid_to_newidx[wnid] for wnid in labels_wnid]
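
    # illustrative example (values are placeholders): if a ground-truth line reads "490" and
    # idx_to_wnid[490] == "n01751748", the image's final class index is the position of
    # "n01751748" in the sorted list of all 1000 WNIDs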
    output_dir = datasets_dir / IMAGENET_DIRNAME
    with tarfile.open(image_archive_path, mode="r") as image_archive_file:
        # archive members are sorted by filename, which matches the line order of the ground-truth labels
        image_rel_paths = sorted(image_archive_file.getnames())
        for cls, image_rel_path in tqdm(zip(labels, image_rel_paths), total=len(image_rel_paths)):
            output_path = output_dir / str(cls) / image_rel_path
            original_image_file = image_archive_file.extractfile(image_rel_path)
            processed_image = _process_image(original_image_file, target_size)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            processed_image.save(output_path.as_posix())


if __name__ == "__main__":
    main()
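
# Usage sketch (paths are illustrative): place ILSVRC2012_img_val.tar and
# ILSVRC2012_devkit_t12.tar.gz in /data, then run either
#   DATASETS_DIR=/data ./process_dataset.py
# or
#   ./process_dataset.py --dataset-dir /data --target-size 224,224
# Processed images are written to /data/imagenet/<class-index>/.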