[SE3Transformer/PyT] Fix multi-GPU training bug

Author: Alexandre Milesi, 2021-09-10 13:58:25 +02:00 (committed by Krzysztof Kudrynski)
parent 649776f79a
commit 8d31298c56
9 changed files with 13 additions and 11 deletions


@@ -472,7 +472,7 @@ Our results were obtained by running the `scripts/train.sh` training script in t
 | GPUs | Batch size / GPU | Absolute error - TF32 | Absolute error - mixed precision | Time to train - TF32 | Time to train - mixed precision | Time to train speedup (mixed precision to TF32) |
 |:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
 | 1 | 240 | 0.03456 | 0.03460 | 1h23min | 1h03min | 1.32x |
-| 8 | 240 | 0.03417 | 0.03424 | 35min | 27min | 1.30x |
+| 8 | 240 | 0.03417 | 0.03424 | 15min | 12min | 1.25x |
 
 ##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
@@ -482,7 +482,7 @@ Our results were obtained by running the `scripts/train.sh` training script in t
 | GPUs | Batch size / GPU | Absolute error - FP32 | Absolute error - mixed precision | Time to train - FP32 | Time to train - mixed precision | Time to train speedup (mixed precision to FP32) |
 |:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
 | 1 | 240 | 0.03432 | 0.03439 | 2h25min | 1h33min | 1.56x |
-| 8 | 240 | 0.03380 | 0.03495 | 1h08min | 44min | 1.55x |
+| 8 | 240 | 0.03380 | 0.03495 | 29min | 20min | 1.45x |
 
 #### Training performance results


PyTorch/DrugDiscovery/SE3Transformer/scripts/train.sh (mode changed: Normal file → Executable file)

@@ -3,8 +3,8 @@
 # CLI args with defaults
 BATCH_SIZE=${1:-240}
 AMP=${2:-true}
-NUM_EPOCHS=${3:-300}
-LEARNING_RATE=${4:-0.004}
+NUM_EPOCHS=${3:-130}
+LEARNING_RATE=${4:-0.01}
 WEIGHT_DECAY=${5:-0.1}
 # choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
@@ -17,7 +17,7 @@ python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0
     --batch_size "$BATCH_SIZE" \
     --epochs "$NUM_EPOCHS" \
     --lr "$LEARNING_RATE" \
-    --min_lr 0.0001 \
+    --min_lr 0.00001 \
     --weight_decay "$WEIGHT_DECAY" \
     --use_layer_norm \
     --norm \
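
The new defaults shorten training to 130 epochs, raise the peak learning rate to 0.01, and lower the learning-rate floor (--min_lr) to 1e-5. The diff does not show which scheduler consumes --lr and --min_lr; assuming a cosine-annealing schedule, here is a minimal sketch of how such a peak/floor pair maps onto PyTorch's built-in scheduler (the optimizer and scheduler choices below are illustrative assumptions, not read from the repository):

# Hedged sketch: a peak LR of 0.01 decayed to a floor of 1e-5 over 130 epochs,
# expressed with cosine annealing. The scheduler actually used by the training
# script is not visible in this diff; CosineAnnealingLR is only a stand-in.
import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)   # LEARNING_RATE default
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=130, eta_min=0.00001)                   # NUM_EPOCHS and --min_lr defaults

for epoch in range(130):
    # ... one training epoch would run here ...
    scheduler.step()  # LR follows a cosine curve from 0.01 down toward 1e-5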


@@ -80,7 +80,8 @@ class QM9DataModule(DataModule):
         qm9_kwargs = dict(label_keys=[self.task], verbose=False, raw_dir=str(data_dir))
         if precompute_bases:
             bases_kwargs = dict(max_degree=num_degrees - 1, use_pad_trick=using_tensor_cores(amp), amp=amp)
-            full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size, **qm9_kwargs)
+            full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size,
+                                                     num_workers=num_workers, **qm9_kwargs)
         else:
             full_dataset = QM9EdgeDataset(**qm9_kwargs)
@@ -134,7 +135,7 @@ class QM9DataModule(DataModule):
 class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
     """ Dataset extending the QM9 dataset from DGL with precomputed (cached in RAM) pairwise bases """
-    def __init__(self, bases_kwargs: dict, batch_size: int, *args, **kwargs):
+    def __init__(self, bases_kwargs: dict, batch_size: int, num_workers: int, *args, **kwargs):
         """
         :param bases_kwargs: Arguments to feed the bases computation function
         :param batch_size: Batch size to use when iterating over the dataset for computing bases
@@ -142,13 +143,14 @@ class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
         self.bases_kwargs = bases_kwargs
         self.batch_size = batch_size
         self.bases = None
+        self.num_workers = num_workers
         super().__init__(*args, **kwargs)
 
     def load(self):
         super().load()
         # Iterate through the dataset and compute bases (pairwise only)
-        # Potential improvement: use multi-GPU and reduction
-        dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size,
+        # Potential improvement: use multi-GPU and gather
+        dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers,
                                 collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples]))
         bases = []
         for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases',
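
Forwarding num_workers into CachedBasesQM9EdgeDataset means the one-off pass that precomputes the pairwise bases now runs through a multi-worker DataLoader instead of loading every graph in the main process. A minimal, self-contained sketch of the same forwarding pattern, with hypothetical names and plain tensors standing in for DGL graphs:

# Toy stand-in for the pattern above: num_workers is stored on the dataset and
# reused by the DataLoader that iterates the dataset once to fill a RAM cache.
import torch
from torch.utils.data import DataLoader, Dataset


class CachedFeatureDataset(Dataset):
    def __init__(self, raw: torch.Tensor, batch_size: int, num_workers: int):
        self.raw = raw
        self.batch_size = batch_size
        self.num_workers = num_workers  # forwarded down to the precomputation DataLoader
        self.cache = None
        self._precompute()

    def __len__(self):
        return len(self.raw)

    def __getitem__(self, idx):
        item = self.raw[idx]
        return (item, self.cache[idx]) if self.cache is not None else item

    def _precompute(self):
        # One pass over the raw samples with a parallel DataLoader; the squared
        # values stand in for the pairwise bases cached by the real dataset.
        loader = DataLoader(self, shuffle=False, batch_size=self.batch_size,
                            num_workers=self.num_workers)
        self.cache = torch.cat([batch ** 2 for batch in loader])


if __name__ == '__main__':
    dataset = CachedFeatureDataset(torch.arange(10.0), batch_size=4, num_workers=2)
    print(dataset[3])  # (tensor(3.), tensor(9.))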


@@ -81,7 +81,7 @@ def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callb
     return checkpoint['epoch']
 
-def train_epoch(train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
+def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
     losses = []
     for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
                          desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
@@ -147,7 +147,7 @@ def train(model: nn.Module,
         if isinstance(train_dataloader.sampler, DistributedSampler):
             train_dataloader.sampler.set_epoch(epoch_idx)
-        loss = train_epoch(train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
+        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
         if dist.is_initialized():
             loss = torch.tensor(loss, dtype=torch.float, device=device)
             torch.distributed.all_reduce(loss)
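
train_epoch now takes the model as an explicit argument and train() passes it through; before this change the function signature had no model parameter. Below is a minimal sketch (hypothetical model, data, and loop, not the repository's actual code) of the resulting call pattern: the forward/backward pass runs on the object handed in by train(), so when that object is the DistributedDataParallel wrapper, backward() also synchronizes gradients across ranks.

# Hedged sketch of the call pattern after the fix, with a toy model and dataset.
import torch
import torch.nn as nn


def train_epoch(model, dataloader, loss_fn, optimizer):
    losses = []
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)  # forward on the model that was passed in
        loss.backward()                         # with a DDP-wrapped model, this all-reduces gradients
        optimizer.step()
        losses.append(loss.item())
    return sum(losses) / len(losses)


def train(model, dataloader, loss_fn, optimizer, epochs):
    for epoch_idx in range(epochs):
        # The model is threaded through explicitly, mirroring the change above.
        loss = train_epoch(model, dataloader, loss_fn, optimizer)
        print(f'epoch {epoch_idx}: loss {loss:.4f}')


if __name__ == '__main__':
    model = nn.Linear(8, 1)  # in multi-GPU runs this would be the DistributedDataParallel wrapper
    data = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(10)]
    train(model, data, nn.MSELoss(), torch.optim.SGD(model.parameters(), lr=0.01), epochs=2)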