[SE3Transformer/PyT] Fix multi-GPU training bug

Alexandre Milesi 2021-09-10 13:58:25 +02:00 committed by Krzysztof Kudrynski
parent 649776f79a
commit 8d31298c56
9 changed files with 13 additions and 11 deletions


@@ -472,7 +472,7 @@ Our results were obtained by running the `scripts/train.sh` training script in t
 | GPUs | Batch size / GPU | Absolute error - TF32 | Absolute error - mixed precision | Time to train - TF32 | Time to train - mixed precision | Time to train speedup (mixed precision to TF32) |
 |:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
 | 1 | 240 | 0.03456 | 0.03460 | 1h23min | 1h03min | 1.32x |
-| 8 | 240 | 0.03417 | 0.03424 | 35min | 27min | 1.30x |
+| 8 | 240 | 0.03417 | 0.03424 | 15min | 12min | 1.25x |
 ##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
@@ -482,7 +482,7 @@ Our results were obtained by running the `scripts/train.sh` training script in t
 | GPUs | Batch size / GPU | Absolute error - FP32 | Absolute error - mixed precision | Time to train - FP32 | Time to train - mixed precision | Time to train speedup (mixed precision to FP32) |
 |:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
 | 1 | 240 | 0.03432 | 0.03439 | 2h25min | 1h33min | 1.56x |
-| 8 | 240 | 0.03380 | 0.03495 | 1h08min | 44min | 1.55x |
+| 8 | 240 | 0.03380 | 0.03495 | 29min | 20min | 1.45x |
 #### Training performance results
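As a quick check on the updated rows, the speedup column is just the ratio of the two training times. A minimal sketch of that arithmetic, using the new 8-GPU numbers from the tables above:

```python
# Recompute the "Time to train speedup" column from the table values above.
def speedup(baseline_minutes: float, mixed_minutes: float) -> float:
    return round(baseline_minutes / mixed_minutes, 2)

print(speedup(15, 12))  # A100, 8 GPUs: 1.25x
print(speedup(29, 20))  # V100, 8 GPUs: 1.45x
```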

PyTorch/DrugDiscovery/SE3Transformer/scripts/train.sh Normal file → Executable file

@@ -3,8 +3,8 @@
 # CLI args with defaults
 BATCH_SIZE=${1:-240}
 AMP=${2:-true}
-NUM_EPOCHS=${3:-300}
-LEARNING_RATE=${4:-0.004}
+NUM_EPOCHS=${3:-130}
+LEARNING_RATE=${4:-0.01}
 WEIGHT_DECAY=${5:-0.1}
 # choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
@@ -17,7 +17,7 @@ python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0
 --batch_size "$BATCH_SIZE" \
 --epochs "$NUM_EPOCHS" \
 --lr "$LEARNING_RATE" \
---min_lr 0.0001 \
+--min_lr 0.00001 \
 --weight_decay "$WEIGHT_DECAY" \
 --use_layer_norm \
 --norm \
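The new defaults (130 epochs, peak learning rate 0.01) go together with the lower `--min_lr 0.00001`, which only matters when a schedule decays the learning rate from `--lr` down to `--min_lr`. Below is a minimal sketch of how those three values could be wired up, assuming a cosine-annealing schedule and a toy model; the scheduler choice is an illustration, not necessarily what the training code uses.

```python
# Minimal sketch (assumption: cosine annealing) wiring the train.sh defaults
# --epochs 130, --lr 0.01 and --min_lr 0.00001 into a PyTorch LR schedule.
import torch

NUM_EPOCHS, LEARNING_RATE, MIN_LR, WEIGHT_DECAY = 130, 0.01, 0.00001, 0.1

model = torch.nn.Linear(16, 1)  # toy stand-in for the SE(3)-Transformer
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=MIN_LR)

for epoch in range(NUM_EPOCHS):
    # ... one training epoch would run here ...
    scheduler.step()  # decays the LR from 0.01 towards 0.00001
```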


@@ -80,7 +80,8 @@ class QM9DataModule(DataModule):
         qm9_kwargs = dict(label_keys=[self.task], verbose=False, raw_dir=str(data_dir))
         if precompute_bases:
             bases_kwargs = dict(max_degree=num_degrees - 1, use_pad_trick=using_tensor_cores(amp), amp=amp)
-            full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size, **qm9_kwargs)
+            full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size,
+                                                     num_workers=num_workers, **qm9_kwargs)
         else:
             full_dataset = QM9EdgeDataset(**qm9_kwargs)
@@ -134,7 +135,7 @@ class QM9DataModule(DataModule):
 class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
     """ Dataset extending the QM9 dataset from DGL with precomputed (cached in RAM) pairwise bases """

-    def __init__(self, bases_kwargs: dict, batch_size: int, *args, **kwargs):
+    def __init__(self, bases_kwargs: dict, batch_size: int, num_workers: int, *args, **kwargs):
         """
         :param bases_kwargs: Arguments to feed the bases computation function
         :param batch_size: Batch size to use when iterating over the dataset for computing bases
@@ -142,13 +143,14 @@ class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
         self.bases_kwargs = bases_kwargs
         self.batch_size = batch_size
         self.bases = None
+        self.num_workers = num_workers
         super().__init__(*args, **kwargs)

     def load(self):
         super().load()
         # Iterate through the dataset and compute bases (pairwise only)
-        # Potential improvement: use multi-GPU and reduction
-        dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size,
+        # Potential improvement: use multi-GPU and gather
+        dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers,
                                 collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples]))
         bases = []
         for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases',
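The data-loading fix threads `num_workers` through to the one-off `DataLoader` that precomputes the pairwise bases, so that pass can use several CPU workers instead of the main process alone. The sketch below shows the same wiring with a toy dataset and a stand-in bases computation; `ToyGraphDataset` and `compute_bases` are hypothetical, and only the `DataLoader(..., num_workers=...)` pattern mirrors the diff.

```python
# Sketch of precomputing per-sample data with a multi-worker DataLoader,
# mirroring the num_workers wiring added above. The dataset and the
# compute_bases function are illustrative stand-ins, not the QM9 code.
import torch
from torch.utils.data import DataLoader, Dataset

class ToyGraphDataset(Dataset):
    def __init__(self, n: int = 1024):
        # each "graph" is represented here only by 3D node coordinates
        self.samples = [torch.randn(8, 3) for _ in range(n)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def compute_bases(coords: torch.Tensor) -> torch.Tensor:
    # stand-in for the pairwise bases computation
    return torch.cdist(coords, coords)

def collate(samples):
    return torch.stack(samples)

if __name__ == "__main__":
    loader = DataLoader(ToyGraphDataset(), shuffle=False, batch_size=64,
                        num_workers=4, collate_fn=collate)
    bases = [compute_bases(batch) for batch in loader]
```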


@@ -81,7 +81,7 @@ def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callb
     return checkpoint['epoch']

-def train_epoch(train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
+def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
     losses = []
     for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
                          desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
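The core of the fix in the training loop is that `train_epoch` now receives the model explicitly instead of picking it up from the enclosing scope. What the body of such a function typically does is sketched below; the batch layout, loss handling and AMP usage are illustrative assumptions rather than the repository's exact code.

```python
# Illustrative body for a train_epoch(model, ...) of the shape added above.
# Batch format, loss function and AMP details are assumptions.
import torch

def train_epoch(model, train_dataloader, loss_fn, grad_scaler, optimizer, device):
    losses = []
    for inputs, targets in train_dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():
            preds = model(inputs)           # the explicitly passed model is used here
            loss = loss_fn(preds, targets)
        grad_scaler.scale(loss).backward()  # AMP: scale the loss, then backprop
        grad_scaler.step(optimizer)         # unscales gradients and steps the optimizer
        grad_scaler.update()
        losses.append(loss.item())
    return sum(losses) / len(losses)
```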
@@ -147,7 +147,7 @@ def train(model: nn.Module,
         if isinstance(train_dataloader.sampler, DistributedSampler):
             train_dataloader.sampler.set_epoch(epoch_idx)
-        loss = train_epoch(train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
+        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
         if dist.is_initialized():
             loss = torch.tensor(loss, dtype=torch.float, device=device)
             torch.distributed.all_reduce(loss)
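The call site shows the surrounding multi-GPU plumbing: the `DistributedSampler` is re-seeded every epoch and the per-rank epoch loss is summed across ranks with `all_reduce`. A short sketch of that loop, reusing the `train_epoch` sketch above; dividing the summed loss by the world size is an added assumption about how the reduced value is averaged.

```python
# Sketch of the multi-GPU outer loop mirroring the hunk above. Averaging the
# all-reduced loss by world size is an assumption; the rest follows the diff.
import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

def fit(model, train_dataloader, loss_fn, grad_scaler, optimizer, device, epochs):
    for epoch_idx in range(epochs):
        if isinstance(train_dataloader.sampler, DistributedSampler):
            # give every epoch a different, rank-consistent shuffle
            train_dataloader.sampler.set_epoch(epoch_idx)
        loss = train_epoch(model, train_dataloader, loss_fn, grad_scaler, optimizer, device)
        if dist.is_initialized():
            loss = torch.tensor(loss, dtype=torch.float, device=device)
            dist.all_reduce(loss)                        # sum partial losses from all ranks
            loss = loss.item() / dist.get_world_size()   # average them (assumption)
```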