[SE3Transformer/PyT] Fix multi-GPU training bug
This commit is contained in:
parent
649776f79a
commit
8d31298c56
|
@ -472,7 +472,7 @@ Our results were obtained by running the `scripts/train.sh` training script in t
|
||||||
| GPUs | Batch size / GPU | Absolute error - TF32 | Absolute error - mixed precision | Time to train - TF32 | Time to train - mixed precision | Time to train speedup (mixed precision to TF32) |
|
| GPUs | Batch size / GPU | Absolute error - TF32 | Absolute error - mixed precision | Time to train - TF32 | Time to train - mixed precision | Time to train speedup (mixed precision to TF32) |
|
||||||
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
||||||
| 1 | 240 | 0.03456 | 0.03460 | 1h23min | 1h03min | 1.32x |
|
| 1 | 240 | 0.03456 | 0.03460 | 1h23min | 1h03min | 1.32x |
|
||||||
| 8 | 240 | 0.03417 | 0.03424 | 35min | 27min | 1.30x |
|
| 8 | 240 | 0.03417 | 0.03424 | 15min | 12min | 1.25x |
|
||||||
|
|
||||||
|
|
||||||
##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
|
##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
|
||||||
|
@ -482,7 +482,7 @@ Our results were obtained by running the `scripts/train.sh` training script in t
|
||||||
| GPUs | Batch size / GPU | Absolute error - FP32 | Absolute error - mixed precision | Time to train - FP32 | Time to train - mixed precision | Time to train speedup (mixed precision to FP32) |
|
| GPUs | Batch size / GPU | Absolute error - FP32 | Absolute error - mixed precision | Time to train - FP32 | Time to train - mixed precision | Time to train speedup (mixed precision to FP32) |
|
||||||
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
||||||
| 1 | 240 | 0.03432 | 0.03439 | 2h25min | 1h33min | 1.56x |
|
| 1 | 240 | 0.03432 | 0.03439 | 2h25min | 1h33min | 1.56x |
|
||||||
| 8 | 240 | 0.03380 | 0.03495 | 1h08min | 44min | 1.55x |
|
| 8 | 240 | 0.03380 | 0.03495 | 29min | 20min | 1.45x |
|
||||||
|
|
||||||
|
|
||||||
#### Training performance results
|
#### Training performance results
|
||||||
|
|
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_inference.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_inference.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train_multi_gpu.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/benchmark_train_multi_gpu.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/predict.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/predict.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/train.sh
Normal file → Executable file
0
PyTorch/DrugDiscovery/SE3Transformer/scripts/train.sh
Normal file → Executable file
6
PyTorch/DrugDiscovery/SE3Transformer/scripts/train_multi_gpu.sh
Normal file → Executable file
6
PyTorch/DrugDiscovery/SE3Transformer/scripts/train_multi_gpu.sh
Normal file → Executable file
|
@ -3,8 +3,8 @@
|
||||||
# CLI args with defaults
|
# CLI args with defaults
|
||||||
BATCH_SIZE=${1:-240}
|
BATCH_SIZE=${1:-240}
|
||||||
AMP=${2:-true}
|
AMP=${2:-true}
|
||||||
NUM_EPOCHS=${3:-300}
|
NUM_EPOCHS=${3:-130}
|
||||||
LEARNING_RATE=${4:-0.004}
|
LEARNING_RATE=${4:-0.01}
|
||||||
WEIGHT_DECAY=${5:-0.1}
|
WEIGHT_DECAY=${5:-0.1}
|
||||||
|
|
||||||
# choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
|
# choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
|
||||||
|
@ -17,7 +17,7 @@ python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0
|
||||||
--batch_size "$BATCH_SIZE" \
|
--batch_size "$BATCH_SIZE" \
|
||||||
--epochs "$NUM_EPOCHS" \
|
--epochs "$NUM_EPOCHS" \
|
||||||
--lr "$LEARNING_RATE" \
|
--lr "$LEARNING_RATE" \
|
||||||
--min_lr 0.0001 \
|
--min_lr 0.00001 \
|
||||||
--weight_decay "$WEIGHT_DECAY" \
|
--weight_decay "$WEIGHT_DECAY" \
|
||||||
--use_layer_norm \
|
--use_layer_norm \
|
||||||
--norm \
|
--norm \
|
||||||
|
|
|
@ -80,7 +80,8 @@ class QM9DataModule(DataModule):
|
||||||
qm9_kwargs = dict(label_keys=[self.task], verbose=False, raw_dir=str(data_dir))
|
qm9_kwargs = dict(label_keys=[self.task], verbose=False, raw_dir=str(data_dir))
|
||||||
if precompute_bases:
|
if precompute_bases:
|
||||||
bases_kwargs = dict(max_degree=num_degrees - 1, use_pad_trick=using_tensor_cores(amp), amp=amp)
|
bases_kwargs = dict(max_degree=num_degrees - 1, use_pad_trick=using_tensor_cores(amp), amp=amp)
|
||||||
full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size, **qm9_kwargs)
|
full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size,
|
||||||
|
num_workers=num_workers, **qm9_kwargs)
|
||||||
else:
|
else:
|
||||||
full_dataset = QM9EdgeDataset(**qm9_kwargs)
|
full_dataset = QM9EdgeDataset(**qm9_kwargs)
|
||||||
|
|
||||||
|
@ -134,7 +135,7 @@ class QM9DataModule(DataModule):
|
||||||
class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
|
class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
|
||||||
""" Dataset extending the QM9 dataset from DGL with precomputed (cached in RAM) pairwise bases """
|
""" Dataset extending the QM9 dataset from DGL with precomputed (cached in RAM) pairwise bases """
|
||||||
|
|
||||||
def __init__(self, bases_kwargs: dict, batch_size: int, *args, **kwargs):
|
def __init__(self, bases_kwargs: dict, batch_size: int, num_workers: int, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
:param bases_kwargs: Arguments to feed the bases computation function
|
:param bases_kwargs: Arguments to feed the bases computation function
|
||||||
:param batch_size: Batch size to use when iterating over the dataset for computing bases
|
:param batch_size: Batch size to use when iterating over the dataset for computing bases
|
||||||
|
@ -142,13 +143,14 @@ class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
|
||||||
self.bases_kwargs = bases_kwargs
|
self.bases_kwargs = bases_kwargs
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.bases = None
|
self.bases = None
|
||||||
|
self.num_workers = num_workers
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
super().load()
|
super().load()
|
||||||
# Iterate through the dataset and compute bases (pairwise only)
|
# Iterate through the dataset and compute bases (pairwise only)
|
||||||
# Potential improvement: use multi-GPU and reduction
|
# Potential improvement: use multi-GPU and gather
|
||||||
dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size,
|
dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers,
|
||||||
collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples]))
|
collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples]))
|
||||||
bases = []
|
bases = []
|
||||||
for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases',
|
for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases',
|
||||||
|
|
|
@ -81,7 +81,7 @@ def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callb
|
||||||
return checkpoint['epoch']
|
return checkpoint['epoch']
|
||||||
|
|
||||||
|
|
||||||
def train_epoch(train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
|
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
|
||||||
losses = []
|
losses = []
|
||||||
for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
|
for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
|
||||||
desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
|
desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
|
||||||
|
@ -147,7 +147,7 @@ def train(model: nn.Module,
|
||||||
if isinstance(train_dataloader.sampler, DistributedSampler):
|
if isinstance(train_dataloader.sampler, DistributedSampler):
|
||||||
train_dataloader.sampler.set_epoch(epoch_idx)
|
train_dataloader.sampler.set_epoch(epoch_idx)
|
||||||
|
|
||||||
loss = train_epoch(train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
|
loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
loss = torch.tensor(loss, dtype=torch.float, device=device)
|
loss = torch.tensor(loss, dtype=torch.float, device=device)
|
||||||
torch.distributed.all_reduce(loss)
|
torch.distributed.all_reduce(loss)
|
||||||
|
|
Loading…
Reference in a new issue