48 lines
2.3 KiB
Python
48 lines
2.3 KiB
Python
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import argparse
|
|
|
|
PARSER = argparse.ArgumentParser(description="UNet-3D")
|
|
|
|
# Estimator flags
|
|
PARSER.add_argument('--model_dir', required=True, type=str)
|
|
PARSER.add_argument('--exec_mode', choices=['train', 'evaluate', 'train_and_evaluate',
|
|
'predict', 'debug_train', 'debug_predict'], type=str)
|
|
|
|
# Training flags
|
|
PARSER.add_argument('--benchmark', dest='benchmark', action='store_true', default=False)
|
|
PARSER.add_argument('--max_steps', default=16000, type=int)
|
|
PARSER.add_argument('--learning_rate', default=0.0002, type=float)
|
|
PARSER.add_argument('--log_every', default=100, type=int)
|
|
PARSER.add_argument('--log_dir', type=str)
|
|
PARSER.add_argument('--loss', choices=['dice', 'ce', 'dice+ce'], default='dice+ce', type=str)
|
|
PARSER.add_argument('--warmup_steps', default=40, type=int)
|
|
PARSER.add_argument('--normalization', choices=['instancenorm', 'batchnorm', 'groupnorm'],
|
|
default='instancenorm', type=str)
|
|
PARSER.add_argument('--include_background', dest='include_background', action='store_true', default=False)
|
|
PARSER.add_argument('--resume_training', dest='resume_training', action='store_true', default=False)
|
|
|
|
# Augmentations
|
|
PARSER.add_argument('--augment', dest='augment', action='store_true', default=False)
|
|
|
|
# Dataset flags
|
|
PARSER.add_argument('--data_dir', required=True, type=str)
|
|
PARSER.add_argument('--batch_size', default=1, type=int)
|
|
PARSER.add_argument('--fold', default=0, type=int)
|
|
PARSER.add_argument('--num_folds', default=5, type=int)
|
|
|
|
# Tensorflow configuration flags
|
|
PARSER.add_argument('--use_amp', '--amp', dest='use_amp', action='store_true', default=False)
|
|
PARSER.add_argument('--use_xla', '--xla', dest='use_xla', action='store_true', default=False)
|