diff --git a/TensorFlow2/Segmentation/UNet_Medical/README.md b/TensorFlow2/Segmentation/UNet_Medical/README.md
index a5d05522..e22ad4f2 100755
--- a/TensorFlow2/Segmentation/UNet_Medical/README.md
+++ b/TensorFlow2/Segmentation/UNet_Medical/README.md
@@ -231,20 +231,20 @@ For the specifics concerning training and inference, see the [Advanced](#advance
    This script will launch a training on a single fold and store the model’s checkpoint in the directory.
 
-   The script can be run directly by modifying flags if necessary, especially the number of GPUs, which is defined after the `-np` flag. Since the test volume does not have labels, 20% of the training data is used for validation in 5-fold cross-validation manner. The number of fold can be changed using `--crossvalidation_idx` with an integer in range 0-4. For example, to run with 4 GPUs using fold 1 use:
+   The script can be run directly by modifying flags if necessary, especially the number of GPUs, which is defined after the `-np` flag. Since the test volume does not have labels, 20% of the training data is used for validation in a 5-fold cross-validation manner. The fold can be changed using `--fold` with an integer in the range 0-4. For example, to run with 4 GPUs on fold 1, use:
 
    ```bash
-   horovodrun -np 4 python main.py --data_dir /data --model_dir /results --batch_size 1 --exec_mode train --crossvalidation_idx 1 --xla --amp
+   horovodrun -np 4 python main.py --data_dir /data --model_dir /results --batch_size 1 --exec_mode train --fold 1 --xla --amp
    ```
 
    Training will result in a checkpoint file being written to `./results` on the host machine.
 
 6. Start validation/evaluation.
 
-   The trained model can be evaluated by passing the `--exec_mode evaluate` flag. Since evaluation is carried out on a validation dataset, the `--crossvalidation_idx` parameter should be filled. For example:
+   The trained model can be evaluated by passing the `--exec_mode evaluate` flag. Since evaluation is carried out on a validation dataset, the `--fold` parameter must be set. For example:
 
    ```bash
-   python main.py --data_dir /data --model_dir /results --batch_size 1 --exec_mode evaluate --crossvalidation_idx 0 --xla --amp
+   python main.py --data_dir /data --model_dir /results --batch_size 1 --exec_mode evaluate --fold 0 --xla --amp
    ```
 
    Evaluation can also be triggered jointly after training by passing the `--exec_mode train_and_evaluate` flag.
@@ -291,19 +291,20 @@ Other folders included in the root directory are:
 The complete list of the available parameters for the `main.py` script contains:
 * `--exec_mode`: Select the execution mode to run the model (default: `train`). Modes available:
     * `train` - trains model from scratch.
-    * `evaluate` - loads checkpoint (if available) and performs evaluation on validation subset (requires `--crossvalidation_idx` other than `None`).
-    * `train_and_evaluate` - trains model from scratch and performs validation at the end (requires `--crossvalidation_idx` other than `None`).
+    * `evaluate` - loads checkpoint (if available) and performs evaluation on the validation subset (requires `--fold` other than `None`).
+    * `train_and_evaluate` - trains model from scratch and performs validation at the end (requires `--fold` other than `None`).
     * `predict` - loads checkpoint (if available) and runs inference on the test set. Stores the results in `--model_dir` directory.
     * `train_and_predict` - trains model from scratch and performs inference.
 * `--model_dir`: Set the output directory for information related to the model (default: `/results`).
 * `--log_dir`: Set the output directory for logs (default: None).
 * `--data_dir`: Set the input directory containing the dataset (default: `None`).
 * `--batch_size`: Size of each minibatch per GPU (default: `1`).
-* `--crossvalidation_idx`: Selected fold for cross-validation (default: `None`).
+* `--fold`: Selected fold for cross-validation (default: `None`).
 * `--max_steps`: Maximum number of steps (batches) for training (default: `1000`).
 * `--seed`: Set random seed for reproducibility (default: `0`).
 * `--weight_decay`: Weight decay coefficient (default: `0.0005`).
 * `--log_every`: Log performance every n steps (default: `100`).
+* `--evaluate_every`: Evaluate every n steps (default: `0` - evaluate once at the end).
 * `--learning_rate`: Model’s learning rate (default: `0.0001`).
 * `--augment`: Enable data augmentation (default: `False`).
 * `--benchmark`: Enable performance benchmarking (default: `False`). If the flag is set, the script runs in a benchmark mode - each iteration is timed and the performance result (in images per second) is printed at the end. Works for both `train` and `predict` execution modes.
@@ -324,8 +325,8 @@
 usage: main.py [-h]
                [--exec_mode {train,train_and_predict,predict,evaluate,train_and_evaluate}]
                [--model_dir MODEL_DIR] --data_dir DATA_DIR [--log_dir LOG_DIR]
                [--batch_size BATCH_SIZE] [--learning_rate LEARNING_RATE]
-               [--crossvalidation_idx CROSSVALIDATION_IDX]
-               [--max_steps MAX_STEPS] [--weight_decay WEIGHT_DECAY]
+               [--fold FOLD] [--max_steps MAX_STEPS]
+               [--evaluate_every EVALUATE_EVERY] [--weight_decay WEIGHT_DECAY]
                [--log_every LOG_EVERY] [--warmup_steps WARMUP_STEPS]
                [--seed SEED] [--augment] [--benchmark] [--amp] [--xla]
@@ -333,34 +334,39 @@ usage: main.py [-h]
 UNet-medical
 
 optional arguments:
-  -h, --help            show this help message and exit
-  --exec_mode {train,train_and_predict,predict,evaluate,train_and_evaluate}
-                        Execution mode of running the model
-  --model_dir MODEL_DIR
-                        Output directory for information related to the model
-  --data_dir DATA_DIR   Input directory containing the dataset for training
-                        the model
-  --log_dir LOG_DIR     Output directory for training logs
-  --batch_size BATCH_SIZE
-                        Size of each minibatch per GPU
-  --learning_rate LEARNING_RATE
-                        Learning rate coefficient for AdamOptimizer
-  --crossvalidation_idx CROSSVALIDATION_IDX
-                        Chosen fold for cross-validation. Use None to disable
-                        cross-validation
-  --max_steps MAX_STEPS
-                        Maximum number of steps (batches) used for training
-  --weight_decay WEIGHT_DECAY
-                        Weight decay coefficient
-  --log_every LOG_EVERY
-                        Log performance every n steps
-  --warmup_steps WARMUP_STEPS
-                        Number of warmup steps
-  --seed SEED           Random seed
-  --augment             Perform data augmentation during training
-  --benchmark           Collect performance metrics during training
-  --amp                 Train using TF-AMP
-  --xla                 Train using XLA
+  -h, --help            show this help message and exit
+  --exec_mode {train,train_and_predict,predict,evaluate,train_and_evaluate}
+                        Execution mode of running the model
+  --model_dir MODEL_DIR
+                        Output directory for information related to the model
+  --data_dir DATA_DIR   Input directory containing the dataset for training
+                        the model
+  --log_dir LOG_DIR     Output directory for training logs
+  --batch_size BATCH_SIZE
+                        Size of each minibatch per GPU
+  --learning_rate LEARNING_RATE
+                        Learning rate coefficient for AdamOptimizer
+  --fold FOLD           Chosen fold for cross-validation. Use None to disable
+                        cross-validation
+  --max_steps MAX_STEPS
+                        Maximum number of steps (batches) used for training
+  --weight_decay WEIGHT_DECAY
+                        Weight decay coefficient
+  --log_every LOG_EVERY
+                        Log performance every n steps
+  --evaluate_every EVALUATE_EVERY
+                        Evaluate every n steps
+  --warmup_steps WARMUP_STEPS
+                        Number of warmup steps
+  --seed SEED           Random seed
+  --augment             Perform data augmentation during training
+  --no-augment
+  --benchmark           Collect performance metrics during training
+  --no-benchmark
+  --use_amp, --amp      Train using TF-AMP
+  --use_xla, --xla      Train using XLA
+  --use_trt             Use TF-TRT
+  --resume_training     Resume training from a checkpoint
 ```
@@ -420,7 +426,7 @@ horovodrun -np python main.py --data_dir /data [other parameter
 The main result of the training are checkpoints stored by default in `./results/` on the host machine, and in the `/results` in the container. This location can be controlled by the `--model_dir` command-line argument, if a different location was mounted while starting the container.
 
 In the case when the training is run in `train_and_predict` mode, the inference will take place after the training is finished, and inference results will be stored to the `/results` directory.
 
-If the `--exec_mode train_and_evaluate` parameter was used, and if `--crossvalidation_idx` parameter is set to an integer value of {0, 1, 2, 3, 4}, the evaluation of the validation set takes place after the training is completed. The results of the evaluation will be printed to the console.
+If the `--exec_mode train_and_evaluate` parameter was used, and if the `--fold` parameter is set to an integer value of {0, 1, 2, 3, 4}, the evaluation of the validation set takes place after the training is completed. The results of the evaluation will be printed to the console.
 
 ### Inference process
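The data loader itself moves to `data_loading/data_loader.py` with 100% similarity, so the fold-splitting logic the README describes is not visible in this diff. For orientation only, a minimal sketch of a 5-fold split consistent with the README's description (20% of the training data held out per `--fold`); the helper name `make_fold_split` is illustrative, not the repo's API:

```python
# Sketch only: the real logic lives in data_loading/data_loader.py, which this
# diff moves unchanged. make_fold_split is a hypothetical helper for intuition.
import numpy as np

def make_fold_split(num_samples, fold, num_folds=5, seed=0):
    """Return (train_idx, val_idx) with 1/num_folds of the samples held out."""
    rng = np.random.RandomState(seed)      # fixed seed -> reproducible folds
    indices = rng.permutation(num_samples)
    folds = np.array_split(indices, num_folds)
    val_idx = folds[fold]                  # 20% of the data when num_folds == 5
    train_idx = np.concatenate([f for i, f in enumerate(folds) if i != fold])
    return train_idx, val_idx

train_idx, val_idx = make_fold_split(num_samples=30, fold=1)
```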
diff --git a/TensorFlow2/Segmentation/UNet_Medical/utils/data_loader.py b/TensorFlow2/Segmentation/UNet_Medical/data_loading/data_loader.py
similarity index 100%
rename from TensorFlow2/Segmentation/UNet_Medical/utils/data_loader.py
rename to TensorFlow2/Segmentation/UNet_Medical/data_loading/data_loader.py
diff --git a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_1GPU.sh b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_1GPU.sh
index 2f64f682..569f87f6 100755
--- a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_1GPU.sh
+++ b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_1GPU.sh
@@ -15,4 +15,4 @@
 # This script launches U-Net run in FP32 on 1 GPU and trains for 6400 iterations with batch_size 8. Usage:
 # bash unet_FP32_1GPU.sh
 
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size 8 --exec_mode train_and_evaluate --crossvalidation_idx 0 --augment --xla --log_dir $2/log.json
\ No newline at end of file
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size 8 --exec_mode train_and_evaluate --fold 0 --augment --xla --log_dir $2/log.json
\ No newline at end of file
diff --git a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_8GPU.sh b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_8GPU.sh
index 5b48ab36..1dedf120 100755
--- a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_8GPU.sh
+++ b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_8GPU.sh
@@ -15,4 +15,4 @@
 # This script launches U-Net run in FP32 on 8 GPUs and trains for 6400 iterations with batch_size 8. Usage:
 # bash unet_FP32_8GPU.sh
 
-horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size 8 --exec_mode train_and_evaluate --crossvalidation_idx 0 --augment --xla --log_dir $2/log.json
\ No newline at end of file
+horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size 8 --exec_mode train_and_evaluate --fold 0 --augment --xla --log_dir $2/log.json
\ No newline at end of file
diff --git a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER.sh b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER.sh
index 661f9cae..07d5dcc4 100755
--- a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER.sh
+++ b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER.sh
@@ -15,4 +15,4 @@
 # This script launches U-Net run in FP32 on 1 GPU for inference batch_size 1. Usage:
 # bash unet_INFER_FP32.sh
 
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --batch_size 1 --exec_mode predict --xla
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --batch_size 1 --exec_mode predict --xla --fold 0
diff --git a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER_BENCHMARK.sh b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER_BENCHMARK.sh
index 19e30627..34fa4717 100755
--- a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER_BENCHMARK.sh
+++ b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER_BENCHMARK.sh
@@ -15,4 +15,4 @@
 # This script launches U-Net run in FP32 on 1 GPU for inference benchmarking. Usage:
 # bash unet_INFER_BENCHMARK_FP32.sh
 
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --batch_size $3 --exec_mode predict --benchmark --warmup_steps 200 --max_steps 600 --xla
\ No newline at end of file
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --batch_size $3 --exec_mode predict --benchmark --warmup_steps 200 --max_steps 600 --xla --fold 0
\ No newline at end of file
diff --git a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER_BENCHMARK_TF-AMP.sh b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER_BENCHMARK_TF-AMP.sh
index 11f96a73..d3594028 100755
--- a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER_BENCHMARK_TF-AMP.sh
+++ b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER_BENCHMARK_TF-AMP.sh
@@ -15,4 +15,4 @@
 # This script launches U-Net run in FP16 on 1 GPU for inference benchmarking. Usage:
 # bash unet_INFER_BENCHMARK_TF-AMP.sh
 
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --batch_size $3 --exec_mode predict --benchmark --warmup_steps 200 --max_steps 600 --xla --amp
\ No newline at end of file
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --batch_size $3 --exec_mode predict --benchmark --warmup_steps 200 --max_steps 600 --xla --amp --fold 0
\ No newline at end of file
diff --git a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER_TF-AMP.sh b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER_TF-AMP.sh
index d8d0f3f0..ea73bdc2 100755
--- a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER_TF-AMP.sh
+++ b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_INFER_TF-AMP.sh
@@ -15,4 +15,4 @@
 # This script launches U-Net run in FP16 on 1 GPU for inference batch_size 1. Usage:
 # bash unet_INFER_TF-AMP.sh
 
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --batch_size 1 --exec_mode predict --xla --amp
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --batch_size 1 --exec_mode predict --xla --amp --fold 0
diff --git a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TF-AMP_1GPU.sh b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TF-AMP_1GPU.sh
index 77045575..d8a7eb82 100755
--- a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TF-AMP_1GPU.sh
+++ b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TF-AMP_1GPU.sh
@@ -15,4 +15,4 @@
 # This script launches U-Net run in FP16 on 1 GPU and trains for 6400 iterations batch_size 8. Usage:
 # bash unet_TF-AMP_1GPU.sh
 
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size 8 --exec_mode train_and_evaluate --crossvalidation_idx 0 --augment --xla --amp --log_dir $2/log.json
\ No newline at end of file
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size 8 --exec_mode train_and_evaluate --fold 0 --augment --xla --amp --log_dir $2/log.json
\ No newline at end of file
diff --git a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TF-AMP_8GPU.sh b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TF-AMP_8GPU.sh
index 6e7d815d..25a31b6b 100755
--- a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TF-AMP_8GPU.sh
+++ b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TF-AMP_8GPU.sh
@@ -15,4 +15,4 @@
 # This script launches U-Net run in FP16 on 8 GPUs and trains for 6400 iterations batch_size 8. Usage:
 # bash unet_TF-AMP_8GPU.sh
 
-horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size 8 --exec_mode train_and_evaluate --crossvalidation_idx 0 --augment --xla --amp --log_dir $2/log.json
\ No newline at end of file
+horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size 8 --exec_mode train_and_evaluate --fold 0 --augment --xla --amp --log_dir $2/log.json
\ No newline at end of file
diff --git a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_1GPU.sh b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_1GPU.sh
index 3a109a97..fed90f53 100644
--- a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_1GPU.sh
+++ b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_1GPU.sh
@@ -16,9 +16,9 @@
 # Usage:
 # bash unet_TRAIN_FP32_1GPU.sh
 
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 0 --augment --xla > $2/log_FP32_1GPU_fold0.txt
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 1 --augment --xla > $2/log_FP32_1GPU_fold1.txt
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 2 --augment --xla > $2/log_FP32_1GPU_fold2.txt
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 3 --augment --xla > $2/log_FP32_1GPU_fold3.txt
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 4 --augment --xla > $2/log_FP32_1GPU_fold4.txt
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 0 --augment --xla > $2/log_FP32_1GPU_fold0.txt
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 1 --augment --xla > $2/log_FP32_1GPU_fold1.txt
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 2 --augment --xla > $2/log_FP32_1GPU_fold2.txt
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 3 --augment --xla > $2/log_FP32_1GPU_fold3.txt
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 4 --augment --xla > $2/log_FP32_1GPU_fold4.txt
 python utils/parse_results.py --model_dir $2 --exec_mode convergence --env FP32_1GPU
\ No newline at end of file
diff --git a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_8GPU.sh b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_8GPU.sh
index e2577667..5d393b00 100644
--- a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_8GPU.sh
+++ b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_8GPU.sh
@@ -16,9 +16,9 @@
 # Usage:
 # bash unet_TRAIN_FP32_8GPU.sh
 
-horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 0 --augment --xla > $2/log_FP32_8GPU_fold0.txt
-horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 1 --augment --xla > $2/log_FP32_8GPU_fold1.txt
-horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 2 --augment --xla > $2/log_FP32_8GPU_fold2.txt
-horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 3 --augment --xla > $2/log_FP32_8GPU_fold3.txt
-horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 4 --augment --xla > $2/log_FP32_8GPU_fold4.txt
+horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 0 --augment --xla > $2/log_FP32_8GPU_fold0.txt
+horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 1 --augment --xla > $2/log_FP32_8GPU_fold1.txt
+horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 2 --augment --xla > $2/log_FP32_8GPU_fold2.txt
+horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 3 --augment --xla > $2/log_FP32_8GPU_fold3.txt
+horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 4 --augment --xla > $2/log_FP32_8GPU_fold4.txt
 python utils/parse_results.py --model_dir $2 --exec_mode convergence --env FP32_8GPU
\ No newline at end of file
diff --git a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_TF-AMP_1GPU.sh b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_TF-AMP_1GPU.sh
index b7457a4e..c5aa736e 100644
--- a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_TF-AMP_1GPU.sh
+++ b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_TF-AMP_1GPU.sh
@@ -16,9 +16,9 @@
 # Usage:
 # bash unet_TRAIN_TF-AMP_1GPU.sh
 
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 0 --augment --xla --amp > $2/log_TF-AMP_1GPU_fold0.txt
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 1 --augment --xla --amp > $2/log_TF-AMP_1GPU_fold1.txt
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 2 --augment --xla --amp > $2/log_TF-AMP_1GPU_fold2.txt
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 3 --augment --xla --amp > $2/log_TF-AMP_1GPU_fold3.txt
-horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 4 --augment --xla --amp > $2/log_TF-AMP_1GPU_fold4.txt
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 0 --augment --xla --amp > $2/log_TF-AMP_1GPU_fold0.txt
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 1 --augment --xla --amp > $2/log_TF-AMP_1GPU_fold1.txt
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 2 --augment --xla --amp > $2/log_TF-AMP_1GPU_fold2.txt
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 3 --augment --xla --amp > $2/log_TF-AMP_1GPU_fold3.txt
+horovodrun -np 1 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 4 --augment --xla --amp > $2/log_TF-AMP_1GPU_fold4.txt
 python utils/parse_results.py --model_dir $2 --exec_mode convergence --env TF-AMP_1GPU
\ No newline at end of file
diff --git a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_TF-AMP_8GPU.sh b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_TF-AMP_8GPU.sh
index 96af3376..c4d12f1f 100644
--- a/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_TF-AMP_8GPU.sh
+++ b/TensorFlow2/Segmentation/UNet_Medical/examples/unet_TRAIN_TF-AMP_8GPU.sh
@@ -16,9 +16,9 @@
 # Usage:
 # bash unet_TRAIN_TF-AMP_8GPU.sh
 
-horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 0 --augment --xla --amp > $2/log_TF-AMP_8GPU_fold0.txt
-horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 1 --augment --xla --amp > $2/log_TF-AMP_8GPU_fold1.txt
-horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 2 --augment --xla --amp > $2/log_TF-AMP_8GPU_fold2.txt
-horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 3 --augment --xla --amp > $2/log_TF-AMP_8GPU_fold3.txt
-horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --crossvalidation_idx 4 --augment --xla --amp > $2/log_TF-AMP_8GPU_fold4.txt
+horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 0 --augment --xla --amp > $2/log_TF-AMP_8GPU_fold0.txt
+horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 1 --augment --xla --amp > $2/log_TF-AMP_8GPU_fold1.txt
+horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 2 --augment --xla --amp > $2/log_TF-AMP_8GPU_fold2.txt
+horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 3 --augment --xla --amp > $2/log_TF-AMP_8GPU_fold3.txt
+horovodrun -np 8 python main.py --data_dir $1 --model_dir $2 --log_every 100 --max_steps 6400 --batch_size $3 --exec_mode train_and_evaluate --fold 4 --augment --xla --amp > $2/log_TF-AMP_8GPU_fold4.txt
 python utils/parse_results.py --model_dir $2 --exec_mode convergence --env TF-AMP_8GPU
\ No newline at end of file
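The four `unet_TRAIN_*` scripts above repeat one `horovodrun` invocation per fold, with only the fold index and log name changing. A hedged sketch of driving the same sweep from Python, mirroring the 8-GPU TF-AMP variant (nothing below is part of the repo; paths and flags are copied from the scripts):

```python
# Illustrative driver for the 5-fold sweep hardcoded in unet_TRAIN_TF-AMP_8GPU.sh.
import subprocess

def run_fold(data_dir, model_dir, batch_size, fold):
    cmd = ["horovodrun", "-np", "8", "python", "main.py",
           "--data_dir", data_dir, "--model_dir", model_dir,
           "--log_every", "100", "--max_steps", "6400",
           "--batch_size", str(batch_size),
           "--exec_mode", "train_and_evaluate",
           "--fold", str(fold), "--augment", "--xla", "--amp"]
    with open(f"{model_dir}/log_TF-AMP_8GPU_fold{fold}.txt", "w") as log:
        subprocess.run(cmd, stdout=log, check=True)  # raise on non-zero exit

for fold in range(5):
    run_fold("/data", "/results", batch_size=8, fold=fold)
```

Note that the trailing `python utils/parse_results.py ...` call in each script is an unchanged context line, while this same diff renames `utils/parse_results.py` to `runtime/parse_results.py`, so the scripts' final step still points at the old path.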
diff --git a/TensorFlow2/Segmentation/UNet_Medical/main.py b/TensorFlow2/Segmentation/UNet_Medical/main.py
index 6b12ed8b..ee31f897 100755
--- a/TensorFlow2/Segmentation/UNet_Medical/main.py
+++ b/TensorFlow2/Segmentation/UNet_Medical/main.py
@@ -26,10 +26,10 @@ Example:
 import horovod.tensorflow as hvd
 
 from model.unet import Unet
-from run import train, evaluate, predict
-from utils.setup import get_logger, set_flags, prepare_model_dir
-from utils.cmd_util import PARSER, parse_args
-from utils.data_loader import Dataset
+from runtime.run import train, evaluate, predict
+from runtime.setup import get_logger, set_flags, prepare_model_dir
+from runtime.arguments import PARSER, parse_args
+from data_loading.data_loader import Dataset
 
 
 def main():
@@ -47,7 +47,7 @@ def main():
 
     dataset = Dataset(data_dir=params.data_dir,
                       batch_size=params.batch_size,
-                      fold=params.crossvalidation_idx,
+                      fold=params.fold,
                       augment=params.augment,
                       gpu_id=hvd.rank(),
                       num_gpus=hvd.size(),
diff --git a/TensorFlow2/Segmentation/UNet_Medical/utils/cmd_util.py b/TensorFlow2/Segmentation/UNet_Medical/runtime/arguments.py
similarity index 94%
rename from TensorFlow2/Segmentation/UNet_Medical/utils/cmd_util.py
rename to TensorFlow2/Segmentation/UNet_Medical/runtime/arguments.py
index a4c65d29..9d4a4d98 100755
--- a/TensorFlow2/Segmentation/UNet_Medical/utils/cmd_util.py
+++ b/TensorFlow2/Segmentation/UNet_Medical/runtime/arguments.py
@@ -49,7 +49,7 @@ PARSER.add_argument('--learning_rate',
                     default=0.0001,
                     help="""Learning rate coefficient for AdamOptimizer""")
 
-PARSER.add_argument('--crossvalidation_idx',
+PARSER.add_argument('--fold',
                     type=int,
                     default=None,
                     help="""Chosen fold for cross-validation. Use None to disable cross-validation""")
@@ -69,6 +69,11 @@ PARSER.add_argument('--log_every',
                     default=100,
                     help="""Log performance every n steps""")
 
+PARSER.add_argument('--evaluate_every',
+                    type=int,
+                    default=0,
+                    help="""Evaluate every n steps""")
+
 PARSER.add_argument('--warmup_steps',
                     type=int,
                     default=200,
@@ -110,10 +115,11 @@ def parse_args(flags):
         'log_dir': flags.log_dir,
         'batch_size': flags.batch_size,
         'learning_rate': flags.learning_rate,
-        'crossvalidation_idx': flags.crossvalidation_idx,
+        'fold': flags.fold,
         'max_steps': flags.max_steps,
         'weight_decay': flags.weight_decay,
         'log_every': flags.log_every,
+        'evaluate_every': flags.evaluate_every,
         'warmup_steps': flags.warmup_steps,
         'augment': flags.augment,
         'benchmark': flags.benchmark,
diff --git a/TensorFlow2/Segmentation/UNet_Medical/utils/losses.py b/TensorFlow2/Segmentation/UNet_Medical/runtime/losses.py
similarity index 100%
rename from TensorFlow2/Segmentation/UNet_Medical/utils/losses.py
rename to TensorFlow2/Segmentation/UNet_Medical/runtime/losses.py
diff --git a/TensorFlow2/Segmentation/UNet_Medical/utils/parse_results.py b/TensorFlow2/Segmentation/UNet_Medical/runtime/parse_results.py
similarity index 81%
rename from TensorFlow2/Segmentation/UNet_Medical/utils/parse_results.py
rename to TensorFlow2/Segmentation/UNet_Medical/runtime/parse_results.py
index 9601df0c..d6b3738f 100644
--- a/TensorFlow2/Segmentation/UNet_Medical/utils/parse_results.py
+++ b/TensorFlow2/Segmentation/UNet_Medical/runtime/parse_results.py
@@ -17,21 +17,21 @@
 import numpy as np
 import argparse
 
-def process_performance_stats(timestamps, params):
-    warmup_steps = params['warmup_steps']
-    batch_size = params['batch_size']
-    timestamps_ms = 1000 * timestamps[warmup_steps:]
-    timestamps_ms = timestamps_ms[timestamps_ms > 0]
-    latency_ms = timestamps_ms.mean()
-    std = timestamps_ms.std()
-    n = np.sqrt(len(timestamps_ms))
-    throughput_imgps = (1000.0 * batch_size / timestamps_ms).mean()
+def process_performance_stats(timestamps, batch_size, mode):
+    """ Compute throughput and latency statistics
+
+    :param timestamps: Collection of per-step time deltas, in seconds
+    :param batch_size: Number of samples per batch
+    :param mode: Execution mode the statistics refer to
+    :return: Dict with throughput and mean/percentile latencies
+    """
+    timestamps_ms = 1000 * timestamps
+    throughput_imgps = (1000.0 * batch_size / timestamps_ms).mean()
+    stats = {f"throughput_{mode}": throughput_imgps,
+             f"latency_{mode}_mean": timestamps_ms.mean()}
+    for level in [90, 95, 99]:
+        stats.update({f"latency_{mode}_{level}": np.percentile(timestamps_ms, level)})
 
-    stats = [("Throughput Avg", str(throughput_imgps)),
-             ('Latency Avg:', str(latency_ms))]
-    for ci, lvl in zip(["90%:", "95%:", "99%:"],
-                       [1.645, 1.960, 2.576]):
-        stats.append(("Latency_"+ci, str(latency_ms + lvl * std / n)))
     return stats
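The rewritten `process_performance_stats` replaces the normal-approximation confidence intervals (the 1.645/1.960/2.576 multipliers) with empirical percentiles over the measured step times, and returns a flat dict instead of a list of tuples. A minimal usage sketch, assuming per-step deltas in seconds, as the callers in `runtime/run.py` now pass:

```python
# Usage sketch for the new process_performance_stats: feed it per-step time
# deltas in seconds; it reports throughput plus mean/90th/95th/99th latency in ms.
import numpy as np
from runtime.parse_results import process_performance_stats

deltas = np.array([0.051, 0.049, 0.050, 0.052, 0.048])  # synthetic step times
stats = process_performance_stats(deltas, batch_size=8, mode="train")
# stats == {'throughput_train': ~160.0 (images/s),
#           'latency_train_mean': ~50.0 (ms),
#           'latency_train_90': ..., 'latency_train_95': ..., 'latency_train_99': ...}
print(stats)
```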
diff --git a/TensorFlow2/Segmentation/UNet_Medical/run.py b/TensorFlow2/Segmentation/UNet_Medical/runtime/run.py
similarity index 75%
rename from TensorFlow2/Segmentation/UNet_Medical/run.py
rename to TensorFlow2/Segmentation/UNet_Medical/runtime/run.py
index ffe0e305..f9e7deed 100644
--- a/TensorFlow2/Segmentation/UNet_Medical/run.py
+++ b/TensorFlow2/Segmentation/UNet_Medical/runtime/run.py
@@ -19,8 +19,8 @@ from PIL import Image
 import horovod.tensorflow as hvd
 import tensorflow as tf
 
-from utils.losses import partial_losses
-from utils.parse_results import process_performance_stats
+from runtime.losses import partial_losses
+from runtime.parse_results import process_performance_stats
 
 
 def train(params, model, dataset, logger):
@@ -35,7 +35,7 @@ def train(params, model, dataset, logger):
     ce_loss = tf.keras.metrics.Mean(name='ce_loss')
     f1_loss = tf.keras.metrics.Mean(name='dice_loss')
     checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
-    if params.resume_training:
+    if params.resume_training and params.model_dir:
         checkpoint.restore(tf.train.latest_checkpoint(params.model_dir))
 
     @tf.function
@@ -69,26 +69,30 @@ def train(params, model, dataset, logger):
     if params.benchmark:
         assert max_steps * hvd.size() > params.warmup_steps, \
             "max_steps value has to be greater than warmup_steps"
-        timestamps = np.zeros((hvd.size(), max_steps * hvd.size() + 1), dtype=np.float32)
+        timestamps = []
         for iteration, (images, labels) in enumerate(dataset.train_fn(drop_remainder=True)):
-            t0 = time()
             loss = train_step(images, labels, warmup_batch=iteration == 0).numpy()
-            timestamps[hvd.rank(), iteration] = time() - t0
+            if iteration > params.warmup_steps:
+                timestamps.append(time())
             if iteration >= max_steps * hvd.size():
                 break
-        timestamps = np.mean(timestamps, axis=0)
+
         if hvd.rank() == 0:
-            stats = process_performance_stats(timestamps, params)
-            logger.log(step=(),
-                       data={metric: value for (metric, value) in stats})
+            deltas = np.array([timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)])
+            stats = process_performance_stats(deltas, hvd.size() * params.batch_size, mode="train")
+            logger.log(step=(), data=stats)
     else:
         for iteration, (images, labels) in enumerate(dataset.train_fn()):
             train_step(images, labels, warmup_batch=iteration == 0)
-            if (hvd.rank() == 0) and (iteration % params.log_every == 0):
-                logger.log(step=(iteration, max_steps),
-                           data={"train_ce_loss": float(ce_loss.result()),
-                                 "train_dice_loss": float(f1_loss.result()),
-                                 "train_total_loss": float(f1_loss.result() + ce_loss.result())})
+            if hvd.rank() == 0:
+                if iteration % params.log_every == 0:
+                    logger.log(step=(iteration, max_steps),
+                               data={"train_ce_loss": float(ce_loss.result()),
+                                     "train_dice_loss": float(f1_loss.result()),
+                                     "train_total_loss": float(f1_loss.result() + ce_loss.result())})
+
+                if (params.evaluate_every > 0) and (iteration % params.evaluate_every == 0):
+                    evaluate(params, model, dataset, logger, restore_checkpoint=False)
 
             f1_loss.reset_states()
             ce_loss.reset_states()
@@ -101,13 +105,15 @@ def train(params, model, dataset, logger):
     logger.flush()
 
 
-def evaluate(params, model, dataset, logger):
+def evaluate(params, model, dataset, logger, restore_checkpoint=True):
+    if params.fold is None:
+        print("No fold specified for evaluation. Please use --fold [int] to select a fold.")
     ce_loss = tf.keras.metrics.Mean(name='ce_loss')
     f1_loss = tf.keras.metrics.Mean(name='dice_loss')
     checkpoint = tf.train.Checkpoint(model=model)
-    checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
+    if params.model_dir and restore_checkpoint:
+        checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
 
-    @tf.function
     def validation_step(features, labels):
         output_map = model(features, training=False)
         crossentropy_loss, dice_loss = partial_losses(output_map, labels)
@@ -130,7 +136,8 @@ def evaluate(params, model, dataset, logger):
 
 def predict(params, model, dataset, logger):
     checkpoint = tf.train.Checkpoint(model=model)
-    checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
+    if params.model_dir:
+        checkpoint.restore(tf.train.latest_checkpoint(params.model_dir)).expect_partial()
 
     @tf.function
     def prediction_step(features):
@@ -139,16 +146,16 @@ def predict(params, model, dataset, logger):
     if params.benchmark:
         assert params.max_steps > params.warmup_steps, \
             "max_steps value has to be greater than warmup_steps"
-        timestamps = np.zeros(params.max_steps + 1, dtype=np.float32)
+        timestamps = []
         for iteration, images in enumerate(dataset.test_fn(count=None, drop_remainder=True)):
-            t0 = time()
             prediction_step(images)
-            timestamps[iteration] = time() - t0
+            timestamps.append(time())
             if iteration >= params.max_steps:
                 break
-        stats = process_performance_stats(timestamps, params)
-        logger.log(step=(),
-                   data={metric: value for (metric, value) in stats})
+
+        deltas = np.array([timestamps[i + 1] - timestamps[i] for i in range(len(timestamps) - 1)])
+        stats = process_performance_stats(deltas, params.batch_size, mode="test")
+        logger.log(step=(), data=stats)
     else:
         predictions = np.concatenate([prediction_step(images).numpy()
                                       for images in dataset.test_fn(count=1)], axis=0)
@@ -163,4 +170,6 @@ def predict(params, model, dataset, logger):
                           compression="tiff_deflate",
                           save_all=True,
                           append_images=multipage_tif[1:])
+
+    print("Predictions saved at {}".format(output_dir))
     logger.flush()
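Two behavioral changes in `runtime/run.py` are worth spelling out. First, benchmarking now records one wall-clock timestamp after each step (past warmup) and differences consecutive timestamps, so each delta covers the full step-to-step period, including data loading, rather than only the timed call. Second, when `--evaluate_every` is positive, `evaluate` is invoked periodically from the training loop with `restore_checkpoint=False`, reusing the in-memory weights instead of reloading a checkpoint. A self-contained toy illustration of the timing pattern (the busy loop stands in for `train_step`):

```python
# Toy model of the new benchmark timing: consecutive-timestamp deltas measure
# the whole step-to-step period, unlike a stopwatch wrapped around the step alone.
from time import time
import numpy as np

warmup_steps, max_steps = 200, 300
timestamps = []
for iteration in range(max_steps):
    _ = sum(range(10_000))            # stand-in for train_step(...)
    if iteration > warmup_steps:      # mirror params.warmup_steps skipping
        timestamps.append(time())

deltas = np.array([timestamps[i + 1] - timestamps[i]
                   for i in range(len(timestamps) - 1)])
print(f"mean step time: {1000 * deltas.mean():.3f} ms")
```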
diff --git a/TensorFlow2/Segmentation/UNet_Medical/utils/setup.py b/TensorFlow2/Segmentation/UNet_Medical/runtime/setup.py
similarity index 90%
rename from TensorFlow2/Segmentation/UNet_Medical/utils/setup.py
rename to TensorFlow2/Segmentation/UNet_Medical/runtime/setup.py
index 92f4ba3e..819f05f7 100644
--- a/TensorFlow2/Segmentation/UNet_Medical/utils/setup.py
+++ b/TensorFlow2/Segmentation/UNet_Medical/runtime/setup.py
@@ -13,11 +13,13 @@
 # limitations under the License.
 
 import os
+import multiprocessing
+
 import numpy as np
 import tensorflow as tf
+import horovod.tensorflow as hvd
 
 import dllogger as logger
-import horovod.tensorflow as hvd
 from dllogger import StdOutBackend, Verbosity, JSONStreamBackend
 
 
@@ -32,6 +34,7 @@ def set_flags(params):
     os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
     os.environ['TF_SYNC_ON_FINISH'] = '0'
     os.environ['TF_AUTOTUNE_THRESHOLD'] = '2'
+    os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '0'
 
     np.random.seed(params.seed)
     tf.random.set_seed(params.seed)
@@ -45,10 +48,11 @@ def set_flags(params):
     if gpus:
         tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
 
+    tf.config.threading.set_intra_op_parallelism_threads(1)
+    tf.config.threading.set_inter_op_parallelism_threads(max(2, (multiprocessing.cpu_count() // hvd.size()) - 2))
+
     if params.use_amp:
         tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
-    else:
-        os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '0'
 
 
 def prepare_model_dir(params):
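The new threading setup pins intra-op parallelism to a single thread and divides the inter-op thread pool across Horovod ranks. A quick numeric check of the formula introduced above (`max(2, cpu_count() // hvd.size() - 2)`), with the rank count passed explicitly for illustration:

```python
# Sanity check of the inter-op thread formula from runtime/setup.py: each rank
# gets its share of cores minus two, with a floor of two threads.
import multiprocessing

def inter_op_threads(num_ranks, cpu_count=None):
    cpu_count = cpu_count or multiprocessing.cpu_count()
    return max(2, (cpu_count // num_ranks) - 2)

assert inter_op_threads(num_ranks=8, cpu_count=80) == 8   # 80 // 8 - 2
assert inter_op_threads(num_ranks=8, cpu_count=16) == 2   # floor applies
print(inter_op_threads(num_ranks=1))                      # this machine, 1 rank
```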