[SE3Transformer/PyT] Fix multi-GPU training bug
This commit is contained in:
parent
649776f79a
commit
8d31298c56
|
@ -472,7 +472,7 @@ Our results were obtained by running the `scripts/train.sh` training script in t
|
|||
| GPUs | Batch size / GPU | Absolute error - TF32 | Absolute error - mixed precision | Time to train - TF32 | Time to train - mixed precision | Time to train speedup (mixed precision to TF32) |
|
||||
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
||||
| 1 | 240 | 0.03456 | 0.03460 | 1h23min | 1h03min | 1.32x |
|
||||
| 8 | 240 | 0.03417 | 0.03424 | 35min | 27min | 1.30x |
|
||||
| 8 | 240 | 0.03417 | 0.03424 | 15min | 12min | 1.25x |
|
||||
|
||||
|
||||
##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
|
||||
|
@ -482,7 +482,7 @@ Our results were obtained by running the `scripts/train.sh` training script in t
|
|||
| GPUs | Batch size / GPU | Absolute error - FP32 | Absolute error - mixed precision | Time to train - FP32 | Time to train - mixed precision | Time to train speedup (mixed precision to FP32) |
|
||||
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
||||
| 1 | 240 | 0.03432 | 0.03439 | 2h25min | 1h33min | 1.56x |
|
||||
| 8 | 240 | 0.03380 | 0.03495 | 1h08min | 44min | 1.55x |
|
||||
| 8 | 240 | 0.03380 | 0.03495 | 29min | 20min | 1.45x |
|
||||
|
||||
|
||||
#### Training performance results
|
||||
|
|
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_inference.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_inference.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train_multi_gpu.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train_multi_gpu.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/predict.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/predict.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/train.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/train.sh
Normal file → Executable file
6
PyTorch/DrugDiscovery/SE3Transformer/scripts/train_multi_gpu.sh
Normal file → Executable file
6
PyTorch/DrugDiscovery/SE3Transformer/scripts/train_multi_gpu.sh
Normal file → Executable file
|
@ -3,8 +3,8 @@
|
|||
# CLI args with defaults
|
||||
BATCH_SIZE=${1:-240}
|
||||
AMP=${2:-true}
|
||||
NUM_EPOCHS=${3:-300}
|
||||
LEARNING_RATE=${4:-0.004}
|
||||
NUM_EPOCHS=${3:-130}
|
||||
LEARNING_RATE=${4:-0.01}
|
||||
WEIGHT_DECAY=${5:-0.1}
|
||||
|
||||
# choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
|
||||
|
@ -17,7 +17,7 @@ python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0
|
|||
--batch_size "$BATCH_SIZE" \
|
||||
--epochs "$NUM_EPOCHS" \
|
||||
--lr "$LEARNING_RATE" \
|
||||
--min_lr 0.0001 \
|
||||
--min_lr 0.00001 \
|
||||
--weight_decay "$WEIGHT_DECAY" \
|
||||
--use_layer_norm \
|
||||
--norm \
|
||||
|
|
|
@ -80,7 +80,8 @@ class QM9DataModule(DataModule):
|
|||
qm9_kwargs = dict(label_keys=[self.task], verbose=False, raw_dir=str(data_dir))
|
||||
if precompute_bases:
|
||||
bases_kwargs = dict(max_degree=num_degrees - 1, use_pad_trick=using_tensor_cores(amp), amp=amp)
|
||||
full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size, **qm9_kwargs)
|
||||
full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size,
|
||||
num_workers=num_workers, **qm9_kwargs)
|
||||
else:
|
||||
full_dataset = QM9EdgeDataset(**qm9_kwargs)
|
||||
|
||||
|
@ -134,7 +135,7 @@ class QM9DataModule(DataModule):
|
|||
class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
|
||||
""" Dataset extending the QM9 dataset from DGL with precomputed (cached in RAM) pairwise bases """
|
||||
|
||||
def __init__(self, bases_kwargs: dict, batch_size: int, *args, **kwargs):
|
||||
def __init__(self, bases_kwargs: dict, batch_size: int, num_workers: int, *args, **kwargs):
|
||||
"""
|
||||
:param bases_kwargs: Arguments to feed the bases computation function
|
||||
:param batch_size: Batch size to use when iterating over the dataset for computing bases
|
||||
|
@ -142,13 +143,14 @@ class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
|
|||
self.bases_kwargs = bases_kwargs
|
||||
self.batch_size = batch_size
|
||||
self.bases = None
|
||||
self.num_workers = num_workers
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def load(self):
|
||||
super().load()
|
||||
# Iterate through the dataset and compute bases (pairwise only)
|
||||
# Potential improvement: use multi-GPU and reduction
|
||||
dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size,
|
||||
# Potential improvement: use multi-GPU and gather
|
||||
dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers,
|
||||
collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples]))
|
||||
bases = []
|
||||
for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases',
|
||||
|
|
|
@ -81,7 +81,7 @@ def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callb
|
|||
return checkpoint['epoch']
|
||||
|
||||
|
||||
def train_epoch(train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
|
||||
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
|
||||
losses = []
|
||||
for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
|
||||
desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
|
||||
|
@ -147,7 +147,7 @@ def train(model: nn.Module,
|
|||
if isinstance(train_dataloader.sampler, DistributedSampler):
|
||||
train_dataloader.sampler.set_epoch(epoch_idx)
|
||||
|
||||
loss = train_epoch(train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
|
||||
loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
|
||||
if dist.is_initialized():
|
||||
loss = torch.tensor(loss, dtype=torch.float, device=device)
|
||||
torch.distributed.all_reduce(loss)
|
||||
|
|
Loading…
Reference in a new issue