# DeepLearningExamples/TensorFlow/Segmentation/UNet_Medical/utils/cmd_util.py
#
# 124 lines
# 4.5 KiB
# Python

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Command line argument parsing"""
import argparse
from munch import Munch
# Module-level argument parser; every option below is registered at import
# time. Nothing is parsed here -- callers invoke PARSER.parse_args() themselves.
PARSER = argparse.ArgumentParser(description="UNet-medical")

# --- Execution mode and filesystem locations --------------------------------
PARSER.add_argument('--exec_mode',
                    choices=['train',
                             'train_and_predict',
                             'predict',
                             'evaluate',
                             'train_and_evaluate'],
                    type=str,
                    default='train_and_evaluate',
                    help="""Execution mode of running the model""")

PARSER.add_argument('--model_dir',
                    type=str,
                    default='./results',
                    help="""Output directory for information related to the model""")

# The only mandatory option: parsing fails if --data_dir is not supplied.
PARSER.add_argument('--data_dir',
                    type=str,
                    required=True,
                    help="""Input directory containing the dataset for training the model""")

PARSER.add_argument('--log_dir',
                    type=str,
                    default=None,
                    help="""Output directory for training logs""")

# --- Numeric hyper-parameters -----------------------------------------------
PARSER.add_argument('--batch_size',
                    type=int,
                    default=1,
                    help="""Size of each minibatch per GPU""")

PARSER.add_argument('--learning_rate',
                    type=float,
                    default=0.0001,
                    help="""Learning rate coefficient for AdamOptimizer""")

# Defaults to None; because type=int, None can only be obtained by omitting
# the flag on the command line.
PARSER.add_argument('--crossvalidation_idx',
                    type=int,
                    default=None,
                    help="""Chosen fold for cross-validation. Use None to disable cross-validation""")

PARSER.add_argument('--max_steps',
                    type=int,
                    default=1000,
                    help="""Maximum number of steps (batches) used for training""")

PARSER.add_argument('--weight_decay',
                    type=float,
                    default=0.0005,
                    help="""Weight decay coefficient""")

PARSER.add_argument('--log_every',
                    type=int,
                    default=100,
                    help="""Log performance every n steps""")

PARSER.add_argument('--warmup_steps',
                    type=int,
                    default=200,
                    help="""Number of warmup steps""")

PARSER.add_argument('--seed',
                    type=int,
                    default=0,
                    help="""Random seed""")

# --- Boolean switches -------------------------------------------------------
# Each --flag / --no-flag pair writes to the same dest, so whichever of the
# two appears last on the command line wins. Every switch gets an explicit
# default for its *own* dest right after it is registered.
PARSER.add_argument('--augment', dest='augment', action='store_true',
                    help="""Perform data augmentation during training""")
PARSER.add_argument('--no-augment', dest='augment', action='store_false')
PARSER.set_defaults(augment=False)

PARSER.add_argument('--benchmark', dest='benchmark', action='store_true',
                    help="""Collect performance metrics during training""")
PARSER.add_argument('--no-benchmark', dest='benchmark', action='store_false')
# Default belongs to 'benchmark' (not 'augment', which is already set above).
PARSER.set_defaults(benchmark=False)

PARSER.add_argument('--use_amp', dest='use_amp', action='store_true',
                    help="""Train using TF-AMP""")
PARSER.set_defaults(use_amp=False)

PARSER.add_argument('--use_xla', dest='use_xla', action='store_true',
                    help="""Train using XLA""")
# Default belongs to 'use_xla' (not 'use_amp', which is already set above).
PARSER.set_defaults(use_xla=False)

PARSER.add_argument('--use_trt', dest='use_trt', action='store_true',
                    help="""Use TF-TRT""")
PARSER.set_defaults(use_trt=False)
def _cmd_params(flags):
    """Copy the known command-line options from ``flags`` into a ``Munch``.

    Args:
        flags: Object exposing every option listed below as an attribute
            (presumably the namespace produced by ``PARSER.parse_args()``).

    Returns:
        A ``Munch`` whose keys are the option names and whose values are the
        corresponding attributes read from ``flags``.

    Raises:
        AttributeError: If ``flags`` lacks any of the expected attributes.
    """
    # Order matters only for the resulting key order; it mirrors the
    # sequence in which the values are looked up on ``flags``.
    option_names = (
        'exec_mode',
        'model_dir',
        'data_dir',
        'log_dir',
        'batch_size',
        'learning_rate',
        'crossvalidation_idx',
        'max_steps',
        'weight_decay',
        'log_every',
        'warmup_steps',
        'augment',
        'benchmark',
        'seed',
        'use_amp',
        'use_trt',
        'use_xla',
    )
    collected = {}
    for option in option_names:
        collected[option] = getattr(flags, option)
    return Munch(collected)