
Commit bb33315

dongyang0122, dongy, and mingxin-zheng authored
Update notebooks of acceleration and performance (#1179)
Update notebooks of acceleration and performance.

### Description
Update notebooks of acceleration and performance.
- acceleration/automatic_mixed_precision.ipynb
- acceleration/dataset_type_performance.ipynb
- acceleration/fast_training_tutorial.ipynb
- performance_profiling

### Checks
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [ ] Avoid including large-size files in the PR.
- [ ] Clean up long text outputs from code cells in the notebook.
- [ ] For security purposes, please check the contents and remove any sensitive info such as user names and private key.
- [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder
- [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>`

Signed-off-by: dongy <[email protected]>
Signed-off-by: dongyang0122 <[email protected]>
Signed-off-by: Mingxin Zheng <[email protected]>
Co-authored-by: dongy <[email protected]>
Co-authored-by: Mingxin Zheng <[email protected]>
Co-authored-by: Mingxin Zheng <[email protected]>
1 parent 9a9106d commit bb33315

15 files changed (+208 −263 lines)

acceleration/automatic_mixed_precision.ipynb (+55 −50)
Large diffs are not rendered by default.

acceleration/dataset_type_performance.ipynb (+46 −32)
Large diffs are not rendered by default.

acceleration/distributed_training/brats_training_ddp.py (+5 −5)
@@ -1,4 +1,4 @@
-# Copyright 2020 MONAI Consortium
+# Copyright (c) MONAI Consortium
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -17,7 +17,7 @@
 So it's more complicated than other distributed training demo examples.

 Under default settings, each single GPU needs to use ~12GB memory for network training. In addition, in order to
-cache the whole dataset, ~100GB GPU memory are necessary. Therefore, at least 5 NVIDIA TESLA V100 (32G) are needed.
+cache the whole dataset, ~100GB GPU memory are necessary. Therefore, at least 2 NVIDIA TESLA A100 (80G) are needed.
 If you do not have enough GPU memory, you can try to decrease the input parameter `cache_rate`.

 Main steps to set up the distributed training:
@@ -27,7 +27,7 @@
 `--nproc_per_node=NUM_GPUS_PER_NODE`
 `--nnodes=NUM_NODES`
 `--node_rank=INDEX_CURRENT_NODE`
-`--master_addr="192.168.1.1"`
+`--master_addr="localhost"`
 `--master_port=1234`
 For more details, refer to https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py.
 Alternatively, we can also use `torch.multiprocessing.spawn` to start program, but it that case, need to handle
@@ -45,7 +45,7 @@
 Example script to execute this program on every node:
 python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE
 --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE
---master_addr="192.168.1.1" --master_port=1234
+--master_addr="localhost" --master_port=1234
 brats_training_ddp.py -d DIR_OF_TESTDATA

 This example was tested with [Ubuntu 16.04/20.04], [NCCL 2.6.3].
@@ -395,7 +395,7 @@ def main():

 # python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE
 # --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE
-# --master_addr="192.168.1.1" --master_port=1234
+# --master_addr="localhost" --master_port=1234
 # brats_training_ddp.py -d DIR_OF_TESTDATA

 if __name__ == "__main__":
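These launcher flags only describe the rendezvous; the training script itself still has to join the process group. Below is a minimal, hypothetical sketch of that initialization (not code from brats_training_ddp.py), assuming the standard environment variables (`RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`, `LOCAL_RANK`) are provided by `torch.distributed.launch` or `torchrun`:

```python
# Minimal, hypothetical DDP setup driven by the launch command above
# (not the actual code of brats_training_ddp.py).
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

def init_ddp(model: torch.nn.Module) -> DistributedDataParallel:
    # The launcher exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT,
    # so "env://" initialization needs no extra arguments.
    dist.init_process_group(backend="nccl", init_method="env://")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    model = model.to(f"cuda:{local_rank}")
    # One process per GPU: each rank trains on its own shard of the data.
    return DistributedDataParallel(model, device_ids=[local_rank])
```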

acceleration/fast_model_training_guide.md (+27 −17)
@@ -21,7 +21,7 @@ To provide an overview of the fast training techniques in practice, this documen
 * [Execute transforms on GPU](#2-execute-transforms-on-gpu)
 * [Adapt `cuCIM` to execute GPU transforms](#3-adapt-cucim-to-execute-gpu-transforms)
 * [Cache IO and transforms data to GPU](#4-cache-io-and-transforms-data-to-gpu)
-* [Leveraging multi-GPU](#leveraging-multi-gpu)
+* [Leveraging multi-GPU distributed training](#leveraging-multi-gpu-distributed-training)
   * Demonstration of multi-GPU training for performance improvement.
 * [Leveraging multi-node distributed training](#leveraging-multi-node-distributed-training)
   * Demonstration of distributed multi-node training for performance improvement.
@@ -182,7 +182,8 @@ MONAI provides a multi-thread `CacheDataset` and `LMDBDataset` to accelerate the
 ### 2. Cache intermediate outcomes into persistent storage

 `PersistentDataset` is similar to `CacheDataset`, where the caches are persisted to disk storage or LMDB for rapid retrieval across experimental runs (as is the case when tuning hyperparameters), or when the entire size of the dataset exceeds available memory. `PersistentDataset` could achieve similar performance when comparing to `CacheDataset` in [Datasets experiment](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/dataset_type_performance.ipynb).
-![cachedataset speed](../figures/datasets_speed.png) with an SSD storage.
+
+![cachedataset speed](../figures/datasets_speed.png)

 ### 3. SmartCache mechanism for large datasets
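For readers unfamiliar with `PersistentDataset`, a minimal sketch of the disk-caching pattern described in this hunk follows; the file paths and cache directory are hypothetical placeholders:

```python
# Illustrative PersistentDataset usage: deterministic transform outputs are cached on
# disk, so later runs (e.g. hyperparameter sweeps) skip the expensive preprocessing.
from monai.data import DataLoader, PersistentDataset
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ScaleIntensityd

data = [{"image": "img0.nii.gz", "label": "seg0.nii.gz"}]  # placeholder paths
transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    ScaleIntensityd(keys="image"),
])
ds = PersistentDataset(data=data, transform=transforms, cache_dir="./persistent_cache")
loader = DataLoader(ds, batch_size=2, num_workers=4)
```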

@@ -208,7 +209,14 @@ a `ThreadDataLoader` example is available at [Spleen fast training tutorial](htt

 ## Algorithmic improvement

-In most deep learning applications, algorithmic improvement has been witnessed to be effective for boosting training efficiency and performance (for example, from AlexNet to ResNet). The improvement may come from a novel loss function, or a sophisticated optimizer, or a different learning rate scheduler, or the combination of all previous items. For our demo applications of 3D medical image segmentation, we would like to further speed up training from the algorithmic perspective. The default loss function is soft Dice loss. And we changed it to `DiceCELoss` from MONAI to further improve the model convergence. Because the `DiceCELoss` combines both Dice loss and multi-class cross-entropy loss (which is suitable for the softmax formulation), and balance the importance of global and pixel-wise accuracies. The segmentation quality can be largely improved. The following figure shows the great improvement on model convergence after we change Dice loss to `DiceCELoss`, with or without enabling automated mixed precision (AMP).
+In most deep learning applications, algorithmic improvement has been witnessed to be effective in boosting training efficiency and performance (for example, from AlexNet to ResNet).
+The improvement may come from a novel loss function, a sophisticated optimizer, a different learning rate scheduler, or a combination of all previous items.
+For our demo applications of 3D medical image segmentation, we would like to further speed up training from the algorithmic perspective.
+The default loss function is soft Dice loss.
+And we changed it to `DiceCELoss` from MONAI to further improve the model convergence,
+because the `DiceCELoss` combines both Dice loss and multi-class cross-entropy loss (which is suitable for the softmax formulation) and balances the importance of global and pixel-wise accuracies.
+The segmentation quality can be largely improved.
+The following figure shows the great improvement in model convergence after we change the Dice loss to `DiceCELoss`, with or without enabling AMP.

 ![diceceloss](../figures/diceceloss.png)
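The loss swap described in this hunk amounts to a one-line change; a hedged sketch with illustrative arguments, not necessarily the tutorial's exact settings:

```python
# Illustrative: swapping soft Dice loss for DiceCELoss (Dice + cross-entropy).
from monai.losses import DiceCELoss, DiceLoss

loss_baseline = DiceLoss(to_onehot_y=True, softmax=True)    # soft Dice only
loss_improved = DiceCELoss(to_onehot_y=True, softmax=True)  # Dice + multi-class CE
```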

@@ -225,8 +233,11 @@ In 2017, NVIDIA researchers developed a methodology for mixed-precision training

 For the PyTorch 1.6 release, developers at NVIDIA and Facebook moved mixed precision functionality into PyTorch core as the AMP package, `torch.cuda.amp`.

-MONAI workflows can easily set `amp=True/False` in `SupervisedTrainer` or `SupervisedEvaluator` during training or evaluation to enable/disable AMP. And we tried to compare the training speed of spleen segmentation task if AMP ON/OFF on NVIDIA V100 GPU with CUDA 11, obtained some benchmark results:
-![amp v100 results](../figures/amp_training_v100.png)
+MONAI workflows can easily set `amp=True/False` in `SupervisedTrainer` or `SupervisedEvaluator` during training or evaluation to enable/disable AMP.
+We tried to compare the training speed of the spleen segmentation task if AMP ON/OFF on NVIDIA A100 GPU with CUDA 11 and obtained some benchmark results:
+
+![amp a100 results](../figures/amp_training_a100.png)
+
 AMP tutorial is available at [AMP tutorial](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/automatic_mixed_precision.ipynb).

 ### 2. Execute transforms on GPU
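To make the `torch.cuda.amp` mechanics concrete, here is a generic mixed-precision training step; it is an illustrative sketch rather than code from the guide, and in MONAI workflows the equivalent is simply `SupervisedTrainer(..., amp=True)`:

```python
# Generic torch.cuda.amp training step (illustrative): autocast runs the forward pass
# in mixed precision; GradScaler guards against float16 gradient underflow.
import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(model, loss_fn, optimizer, images, labels):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        outputs = model(images)
        loss = loss_fn(outputs, labels)
    scaler.scale(loss).backward()  # scale the loss before backward
    scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
    scaler.update()                # adjust the scale factor for the next iteration
    return loss.item()
```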
@@ -274,28 +285,26 @@ dataset = CacheDataset(..., transform=train_trans)
 Here we convert to PyTorch `Tensor` with `EnsureTyped` transform and move data to GPU with `ToDeviced` transform. `CacheDataset` caches the transform results until `ToDeviced`, so it is in GPU memory. Then in every epoch, the program fetches cached data from GPU memory and only execute the random transform `RandCropByPosNegLabeld` on GPU directly.
 GPU caching example is available at [Spleen fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/fast_training_tutorial.ipynb).

-## Leveraging multi-GPU
+## Leveraging multi-GPU distributed training

-When we have fully utilized a single GPU during training, a natural optimization idea is to partition the dataset and execute model training in parallel on multiple GPUs.
+When we have fully utilized a single GPU during training, a straightforward optimization idea is to partition the dataset and execute model training in parallel on multiple GPUs.

 Additionally, with more GPU devices, we can achieve more benefits:
 - Some training algorithms can converge faster with a larger batch size and the training progress is more stable.
-- If caching data in GPU memory, every GPU only needs to cache a partition, so we can use larger cache rate to cache more data in total to accelerate training.
+- If caching data in GPU memory, every GPU only needs to cache a partition, so we can use a larger cache rate to cache more data in total to accelerate training. Caching data to GPU can largely reduce CPU-based operations during model training. It can greatly improve the model training efficiency.

 For example, during the training of brain tumor segmentation task, with 8 GPUs, we can cache all the data in GPU memory directly and execute the following transforms on GPU device, so it's more than `10x` faster than single GPU training. More details are available at [BraTS distributed training tutorial](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/distributed_training/brats_training_ddp.py).

 ## Leveraging multi-node distributed training

-Distributed data parallelism (DDP) is an important feature of PyTorch to connect multiple GPU devices in multiple nodes to train or evaluate models, it can continuously improve the training speed when we fully leveraged multiple GPUs on a single node.
+Distributed data parallelism (DDP) is an important feature of PyTorch to connect multiple GPU devices in multiple nodes to train or evaluate models. It can further improve the training speed when we fully leveraged multiple GPUs on multiple nodes.

-The distributed data parallel APIs of MONAI are compatible with the native PyTorch distributed module, pytorch-ignite distributed module, Horovod, XLA, and the SLURM platform. MONAI provides rich demos for reference: train/evaluate with `PyTorch DDP`, train/evaluate with `Horovod`, train/evaluate with `Ignite DDP`, partition dataset and train with `SmartCacheDataset`, as well as a real-world training example based on Decathlon challenge Task01 - Brain Tumor segmentation.
+The distributed data parallel APIs of MONAI are compatible with the native PyTorch distributed module, PyTorch-ignite distributed module, Horovod, XLA, and the SLURM platform. Here we provide [a real-world training example](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/distributed_training/brats_training_ddp.py) based on [Decathlon challenge](http://medicaldecathlon.com/) Task01 - Brain Tumor segmentation using the module `torch.distributed.launch`.

 For more details about the PyTorch distributed training setup, please refer to: https://pytorch.org/docs/stable/distributed.html.

 And if using [SLURM](https://developer.nvidia.com/slurm) workload manager, please refer to [SLURM + Singularity MONAI example](https://github.com/UFResearchComputing/MultiNode_MONAI_example).

-We obtained U-Net performance benchmarks of Brain tumor segmentation task for reference (based on CUDA 11, NVIDIA V100 GPUs):
-![distributed training results](../figures/distributed_training.png)
 More details are available at [BraTS distributed training tutorial](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/distributed_training/brats_training_ddp.py).

 ## Examples
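The GPU-caching pattern referenced in this hunk (`EnsureTyped`, then `ToDeviced`, then a cached dataset read by `ThreadDataLoader`) can be sketched as follows; the keys, crop size, and file names are illustrative assumptions rather than the tutorial's exact pipeline:

```python
# Illustrative GPU-caching pipeline: everything up to ToDeviced is deterministic and
# cached once in GPU memory; only the random crop runs per epoch, directly on GPU.
from monai.data import CacheDataset, ThreadDataLoader
from monai.transforms import (
    Compose, EnsureChannelFirstd, EnsureTyped, LoadImaged, RandCropByPosNegLabeld, ToDeviced,
)

train_files = [{"image": "img0.nii.gz", "label": "seg0.nii.gz"}]  # placeholder paths
train_trans = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    EnsureTyped(keys=["image", "label"]),                 # convert to torch.Tensor
    ToDeviced(keys=["image", "label"], device="cuda:0"),  # cached results live on GPU
    RandCropByPosNegLabeld(                               # random transform, executed on GPU
        keys=["image", "label"], label_key="label",
        spatial_size=(96, 96, 96), pos=1, neg=1, num_samples=4,
    ),
])
train_ds = CacheDataset(data=train_files, transform=train_trans, cache_rate=1.0)
# num_workers=0: no multi-process IPC is needed because the cached data is already on GPU.
train_loader = ThreadDataLoader(train_ds, batch_size=2, shuffle=True, num_workers=0)
```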
@@ -305,13 +314,13 @@ With all the above strategies, in this section, we introduce how to apply them t
 ### 1. Spleen segmentation

 - Select the algorithms based on the experiments.
-As a binary segmentation task, we replaced the baseline `Dice` loss with a `DiceCE` loss, it can help improve the convergence. And we tried to analyze the training curve and tuned different parameters of the network and tested several numerical optimizers, finally replaced the baseline `Adam` optimizer with `SGD`. To achieve the target metric (`mean Dice = 0.94` of the `foreground` channel only) it reduces the number of training epochs from 280 to 60.
+As a binary segmentation task, we replaced the baseline `Dice` loss with a `DiceCE` loss, it can help improve the convergence. And we tried to analyze the training curve and tuned different parameters of the network and tested several numerical optimizers, finally replaced the baseline `Adam` optimizer with `SGD`. To achieve the target metric (`mean Dice = 0.94` of the `foreground` channel only) it reduces the number of training epochs from 165 to 95.
 - Optimize GPU utilization.
 1. With `AMP`, the training speed is significantly improved and can achieve almost the same validation metric as without `AMP`.
-2. The deterministic transform results of all the spleen dataset is around 8 GB, which can be cached in a V100 GPU memory. So, we cached all the data in GPU memory and executed the following transforms in GPU directly.
+2. The deterministic transform results of all the spleen dataset is around 8 GB, which can be cached in a A100 GPU memory. So, we cached all the data in GPU memory and executed the following transforms in GPU directly.
 - Replace `DataLoader` with `ThreadDataLoader`. As all the data are cached in GPU, the computation of randomized transforms is on GPU and light-weighted, `ThreadDataLoader` help avoid the IPC cost of multi-processing in `DataLoader` and increase the GPU utilization.

-In summary, with a V100 GPU and the target validation `mean dice = 0.94` of the `forground` channel only, it's more than `100x` speedup compared with the Pytorch regular implementation when achieving the same metric (validation accuracies). And every epoch is `20x` faster than regular training.
+In summary, with a A100 GPU and the target validation `mean dice = 0.94` of the `forground` channel only, it's more than `150x` speedup compared with the Pytorch regular implementation when achieving the same metric (validation accuracies). And every epoch is `50x` faster than regular training.
 ![spleen fast training](../figures/fast_training.png)

 More details are available at [Spleen fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/fast_training_tutorial.ipynb).
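The optimizer swap mentioned in this hunk is likewise a small code change; a hypothetical before/after where the stand-in model and learning rates are placeholders, not the tutorial's tuned values:

```python
# Illustrative optimizer swap for the spleen task: baseline Adam vs. SGD.
import torch

model = torch.nn.Linear(4, 2)  # stand-in for the segmentation network (placeholder)

# Baseline optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Replacement used together with DiceCELoss to reach mean Dice 0.94 in fewer epochs;
# the hyperparameters below are placeholders, not the tuned settings.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-5)
```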
@@ -328,9 +337,10 @@ More details are available at [Spleen fast training tutorial](https://github.com
 1. Single GPU cannot cache all the data in memory, so we split the dataset into eight parts and cache the deterministic transforms result in eight GPUs to avoid duplicated deterministic transforms and `CPU->GPU sync` in every epoch.
 2. We executed all the random augmentations in GPU directly with the `ThreadDataLoader`. The GPU utilization of all the eight GPUs was always almost `100%` during training:
 ![brats gpu utilization](../figures/brats_gpu_utilization.png)
-3. As we already fully leveraged the GPUs, we continuously optimize the training with multiple nodes (32 V100 GPUs in four nodes). The GPU utilization of all the 32 GPUs was always `97%` during training.
+3. As we already fully leveraged the GPUs, we continuously optimize the training with multiple nodes (32 A100 GPUs in four nodes). The GPU utilization of all the 32 GPUs was always `97%` during training.
+
+In summary, combining the optimization strategies, the training time of eight A100 GPUs to achieve the target validation metric was around 40 minutes, which is more than `11x` faster than the baseline with a single GPU. Using four 8-GPU nodes can speed up model processing by `30x` the baseline performance. Our results are achieved based on TensorFloat-32 (TF32) precision format as the default setting in the docker image.

-In summary, combining the optimization strategies, the training time of eight V100 GPUs to achieve the target validation metric was around 40 minutes, which is more than `13x` faster than the baseline with a single GPU. And the training time of 32 V100 GPUs was around `13` minutes, which is `40x` faster than the baseline:
 ![brats benchmark](../figures/brats_benchmark.png)

 More details are available at [BraTS distributed training tutorial](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/distributed_training/brats_training_ddp.py).
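The per-GPU caching strategy in this hunk depends on each rank holding its own shard of the dataset; a minimal sketch with MONAI's `partition_dataset`, assuming the process group is already initialized as in the DDP sketch earlier and `train_files` is the full data list:

```python
# Illustrative: each DDP rank caches only its own partition of the dataset, so the
# aggregate GPU cache spans the whole dataset across 8 (or 32) GPUs.
import torch.distributed as dist
from monai.data import partition_dataset

# Placeholder data list; assumes dist.init_process_group(...) has already run.
train_files = [{"image": f"img{i}.nii.gz", "label": f"seg{i}.nii.gz"} for i in range(8)]
data_part = partition_dataset(
    data=train_files,
    num_partitions=dist.get_world_size(),
    shuffle=True,
    even_divisible=True,
)[dist.get_rank()]
# data_part then feeds a per-rank CacheDataset / ThreadDataLoader as sketched above.
```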
