From 8d31298c56ccf2098b68098cac667a88a429eaad Mon Sep 17 00:00:00 2001
From: Alexandre Milesi
Date: Fri, 10 Sep 2021 13:58:25 +0200
Subject: [PATCH] [SE3Transformer/PyT] Fix multi-GPU training bug

---
 PyTorch/DrugDiscovery/SE3Transformer/README.md                         |  4 ++--
 .../SE3Transformer/scripts/benchmark_inference.sh                      |  0
 .../SE3Transformer/scripts/benchmark_train.sh                          |  0
 .../scripts/benchmark_train_multi_gpu.sh                               |  0
 .../DrugDiscovery/SE3Transformer/scripts/predict.sh                    |  0
 PyTorch/DrugDiscovery/SE3Transformer/scripts/train.sh                  |  0
 .../SE3Transformer/scripts/train_multi_gpu.sh                          |  6 +++---
 .../SE3Transformer/se3_transformer/data_loading/qm9.py                 | 10 ++++++----
 .../SE3Transformer/se3_transformer/runtime/training.py                 |  4 ++--
 9 files changed, 13 insertions(+), 11 deletions(-)
 mode change 100644 => 100755 PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_inference.sh
 mode change 100644 => 100755 PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train.sh
 mode change 100644 => 100755 PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train_multi_gpu.sh
 mode change 100644 => 100755 PyTorch/DrugDiscovery/SE3Transformer/scripts/predict.sh
 mode change 100644 => 100755 PyTorch/DrugDiscovery/SE3Transformer/scripts/train.sh
 mode change 100644 => 100755 PyTorch/DrugDiscovery/SE3Transformer/scripts/train_multi_gpu.sh

diff --git a/PyTorch/DrugDiscovery/SE3Transformer/README.md b/PyTorch/DrugDiscovery/SE3Transformer/README.md
index 59b2c904..a4afc585 100644
--- a/PyTorch/DrugDiscovery/SE3Transformer/README.md
+++ b/PyTorch/DrugDiscovery/SE3Transformer/README.md
@@ -472,7 +472,7 @@ Our results were obtained by running the `scripts/train.sh` training script in t
 | GPUs | Batch size / GPU | Absolute error - TF32 | Absolute error - mixed precision | Time to train - TF32 | Time to train - mixed precision | Time to train speedup (mixed precision to TF32) |
 |:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
 | 1 | 240 | 0.03456 | 0.03460 | 1h23min | 1h03min | 1.32x |
-| 8 | 240 | 0.03417 | 0.03424 | 35min | 27min | 1.30x |
+| 8 | 240 | 0.03417 | 0.03424 | 15min | 12min | 1.25x |
 
 
 ##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
@@ -482,7 +482,7 @@ Our results were obtained by running the `scripts/train.sh` training script in t
 | GPUs | Batch size / GPU | Absolute error - FP32 | Absolute error - mixed precision | Time to train - FP32 | Time to train - mixed precision | Time to train speedup (mixed precision to FP32) |
 |:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
 | 1 | 240 | 0.03432 | 0.03439 | 2h25min | 1h33min | 1.56x |
-| 8 | 240 | 0.03380 | 0.03495 | 1h08min | 44min | 1.55x |
+| 8 | 240 | 0.03380 | 0.03495 | 29min | 20min | 1.45x |
 
 
 #### Training performance results
diff --git a/PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_inference.sh b/PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_inference.sh
old mode 100644
new mode 100755
diff --git a/PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train.sh b/PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train.sh
old mode 100644
new mode 100755
diff --git a/PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train_multi_gpu.sh b/PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train_multi_gpu.sh
old mode 100644
new mode 100755
diff --git a/PyTorch/DrugDiscovery/SE3Transformer/scripts/predict.sh b/PyTorch/DrugDiscovery/SE3Transformer/scripts/predict.sh
old mode 100644
new mode 100755
diff --git a/PyTorch/DrugDiscovery/SE3Transformer/scripts/train.sh b/PyTorch/DrugDiscovery/SE3Transformer/scripts/train.sh
old mode 100644
new mode 100755
diff --git a/PyTorch/DrugDiscovery/SE3Transformer/scripts/train_multi_gpu.sh b/PyTorch/DrugDiscovery/SE3Transformer/scripts/train_multi_gpu.sh
old mode 100644
new mode 100755
index f66eb888..a1344444
--- a/PyTorch/DrugDiscovery/SE3Transformer/scripts/train_multi_gpu.sh
+++ b/PyTorch/DrugDiscovery/SE3Transformer/scripts/train_multi_gpu.sh
@@ -3,8 +3,8 @@
 # CLI args with defaults
 BATCH_SIZE=${1:-240}
 AMP=${2:-true}
-NUM_EPOCHS=${3:-300}
-LEARNING_RATE=${4:-0.004}
+NUM_EPOCHS=${3:-130}
+LEARNING_RATE=${4:-0.01}
 WEIGHT_DECAY=${5:-0.1}
 
 # choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
@@ -17,7 +17,7 @@ python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0
   --batch_size "$BATCH_SIZE" \
   --epochs "$NUM_EPOCHS" \
   --lr "$LEARNING_RATE" \
-  --min_lr 0.0001 \
+  --min_lr 0.00001 \
   --weight_decay "$WEIGHT_DECAY" \
   --use_layer_norm \
   --norm \
diff --git a/PyTorch/DrugDiscovery/SE3Transformer/se3_transformer/data_loading/qm9.py b/PyTorch/DrugDiscovery/SE3Transformer/se3_transformer/data_loading/qm9.py
index b4583986..4cbc753e 100644
--- a/PyTorch/DrugDiscovery/SE3Transformer/se3_transformer/data_loading/qm9.py
+++ b/PyTorch/DrugDiscovery/SE3Transformer/se3_transformer/data_loading/qm9.py
@@ -80,7 +80,8 @@ class QM9DataModule(DataModule):
         qm9_kwargs = dict(label_keys=[self.task], verbose=False, raw_dir=str(data_dir))
         if precompute_bases:
             bases_kwargs = dict(max_degree=num_degrees - 1, use_pad_trick=using_tensor_cores(amp), amp=amp)
-            full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size, **qm9_kwargs)
+            full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size,
+                                                     num_workers=num_workers, **qm9_kwargs)
         else:
             full_dataset = QM9EdgeDataset(**qm9_kwargs)
 
@@ -134,7 +135,7 @@ class QM9DataModule(DataModule):
 class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
     """ Dataset extending the QM9 dataset from DGL with precomputed (cached in RAM) pairwise bases """
 
-    def __init__(self, bases_kwargs: dict, batch_size: int, *args, **kwargs):
+    def __init__(self, bases_kwargs: dict, batch_size: int, num_workers: int, *args, **kwargs):
         """
         :param bases_kwargs: Arguments to feed the bases computation function
         :param batch_size: Batch size to use when iterating over the dataset for computing bases
@@ -142,13 +143,14 @@ class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
         self.bases_kwargs = bases_kwargs
         self.batch_size = batch_size
         self.bases = None
+        self.num_workers = num_workers
         super().__init__(*args, **kwargs)
 
     def load(self):
         super().load()
         # Iterate through the dataset and compute bases (pairwise only)
-        # Potential improvement: use multi-GPU and reduction
-        dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size,
+        # Potential improvement: use multi-GPU and gather
+        dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers,
                                 collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples]))
         bases = []
         for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases',
diff --git a/PyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/training.py b/PyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/training.py
index 2d66f08f..53122779 100644
--- a/PyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/training.py
+++ b/PyTorch/DrugDiscovery/SE3Transformer/se3_transformer/runtime/training.py
@@ -81,7 +81,7 @@ def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callb
     return checkpoint['epoch']
 
 
-def train_epoch(train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
+def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
     losses = []
     for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
                          desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
@@ -147,7 +147,7 @@ def train(model: nn.Module,
         if isinstance(train_dataloader.sampler, DistributedSampler):
             train_dataloader.sampler.set_epoch(epoch_idx)
 
-        loss = train_epoch(train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
+        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
         if dist.is_initialized():
             loss = torch.tensor(loss, dtype=torch.float, device=device)
             torch.distributed.all_reduce(loss)
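Note (not part of the patch): the last hunk threads `model` into `train_epoch` and keeps the existing cross-rank loss reduction shown in its context lines. As a minimal, standalone sketch of that reduction pattern, assuming the process group is set up by `torch.distributed.run` as in `scripts/train_multi_gpu.sh`, averaging a per-epoch loss across ranks can look like the following; the helper name `average_loss_across_ranks` is illustrative and does not exist in the repository.

    import torch
    import torch.distributed as dist


    def average_loss_across_ranks(local_loss: float, device: torch.device) -> float:
        """Average a scalar loss over all ranks; fall back to the local value
        when no process group is initialized (single-GPU runs)."""
        if not dist.is_initialized():
            return local_loss
        loss = torch.tensor(local_loss, dtype=torch.float, device=device)
        dist.all_reduce(loss)  # default reduce op is SUM over all ranks
        return (loss / dist.get_world_size()).item()


    if __name__ == '__main__':
        # Single-process usage: without an initialized process group the helper
        # simply returns the local value.
        print(average_loss_across_ranks(0.5, torch.device('cpu')))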