DeepLearningExamples/TensorFlow/Segmentation/UNet_3D_Medical/runtime/arguments.py
Przemek Strzelczyk 79d4ced0be Adding 3DUnet/TF
2020-07-04 03:28:33 +02:00

49 lines
2.3 KiB
Python

# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
# Module-level argument parser; every option below is registered on it at
# import time, in the order it will be listed by --help.
PARSER = argparse.ArgumentParser(description="UNet-3D")

# Estimator flags
# --model_dir is mandatory. --exec_mode is optional and declares no default,
# so it parses to None when the flag is omitted.
PARSER.add_argument("--model_dir", type=str, required=True)
PARSER.add_argument(
    "--exec_mode",
    type=str,
    choices=[
        "train",
        "evaluate",
        "train_and_evaluate",
        "predict",
        "debug_train",
        "debug_predict",
    ],
)

# Training flags
# Boolean switches use store_true and stay False unless given on the command
# line; their attribute name is derived by argparse from the first long
# option string, so no explicit dest is needed.
PARSER.add_argument("--benchmark", action="store_true", default=False)
PARSER.add_argument("--max_steps", type=int, default=16000)
PARSER.add_argument("--learning_rate", type=float, default=0.0002)
PARSER.add_argument("--log_every", type=int, default=100)
PARSER.add_argument("--log_dir", type=str)
PARSER.add_argument(
    "--loss",
    type=str,
    choices=["dice", "ce", "dice+ce"],
    default="dice+ce",
)
PARSER.add_argument("--warmup_steps", type=int, default=40)
PARSER.add_argument(
    "--normalization",
    type=str,
    choices=["instancenorm", "batchnorm", "groupnorm"],
    default="instancenorm",
)
PARSER.add_argument("--include_background", action="store_true", default=False)
PARSER.add_argument("--resume_training", action="store_true", default=False)

# Augmentations
PARSER.add_argument("--augment", action="store_true", default=False)

# Dataset flags
# --data_dir is mandatory; the remaining dataset options fall back to defaults.
PARSER.add_argument("--data_dir", type=str, required=True)
PARSER.add_argument("--batch_size", type=int, default=1)
PARSER.add_argument("--fold", type=int, default=0)
PARSER.add_argument("--num_folds", type=int, default=5)

# Tensorflow configuration flags
# Each switch accepts a long and a short spelling (--use_amp/--amp,
# --use_xla/--xla); both spellings set the attribute named after the long one.
PARSER.add_argument("--use_amp", "--amp", action="store_true", default=False)
PARSER.add_argument("--use_xla", "--xla", action="store_true", default=False)