[SE3Transformer/DGLPyT] Better low memory mode
This commit is contained in:
parent dcd3bbac09
commit 22d6621dcd
@@ -42,7 +42,6 @@ RUN make -j8
 FROM ${FROM_IMAGE_NAME}
 RUN rm -rf /workspace/*
 WORKDIR /workspace/se3-transformer
 
 # copy built DGL and install it
@@ -55,3 +54,5 @@ ADD . .
 ENV DGLBACKEND=pytorch
 ENV OMP_NUM_THREADS=1
@@ -126,7 +126,13 @@ The following performance optimizations were implemented in this model:
 - The layout (order of dimensions) of the bases tensors is optimized to avoid copies to contiguous memory in the downstream TFN layers
 - When Tensor Cores are available, and the output feature dimension of computed bases is odd, then it is padded with zeros to make more effective use of Tensor Cores (AMP and TF32 precisions)
 - Multiple levels of fusion for TFN convolutions (and radial profiles) are provided and automatically used when conditions are met
-- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`)
+- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`). Overview of memory savings over the official implementation (batch size 100), depending on the precision and the low memory mode:
+
+|                                | FP32  | AMP   |
+|--------------------------------|-------|-------|
+| `--low_memory false` (default) | 4.7x  | 7.1x  |
+| `--low_memory true`            | 29.4x | 43.6x |
 
 **Self-attention optimizations**
@@ -358,7 +364,7 @@ The complete list of the available parameters for the `training.py` script contains:
 - `--pooling`: Type of graph pooling (default: `max`)
 - `--norm`: Apply a normalization layer after each attention block (default: `false`)
 - `--use_layer_norm`: Apply layer normalization between MLP layers (default: `false`)
-- `--low_memory`: If true, will use fused ops that are slower but use less memory (expect 25 percent less memory). Only has an effect if AMP is enabled on NVIDIA Volta GPUs or if running on Ampere GPUs (default: `false`)
+- `--low_memory`: If true, will use ops that are slower but use less memory (default: `false`)
 - `--num_degrees`: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: `4`)
 - `--num_channels`: Number of channels for the hidden features (default: `32`)
@@ -407,7 +413,8 @@ The training script is `se3_transformer/runtime/training.py`, to be run as a module.
 
 By default, the resulting logs are stored in `/results/`. This can be changed with `--log_dir`.
 
-You can connect your existing Weights & Biases account by setting the `WANDB_API_KEY` environment variable.
+You can connect your existing Weights & Biases account by setting the `WANDB_API_KEY` environment variable and enabling the `--wandb` flag.
+If no API key is set, `--wandb` will log the run anonymously to Weights & Biases.
 
 **Checkpoints**
@@ -573,6 +580,11 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
 
 ### Changelog
 
+November 2021:
+- Improved low memory mode to give further 6x memory savings
+- Disabled W&B logging by default
+- Fixed persistent workers when using one data loading process
+
 October 2021:
 - Updated README performance tables
 - Fixed shape mismatch when using partially fused TFNs per output degree
@@ -46,7 +46,8 @@ class DataModule(ABC):
         if dist.is_initialized():
             dist.barrier(device_ids=[get_local_rank()])
 
-        self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs}
+        self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': dataloader_kwargs.get('num_workers', 0) > 0,
+                                  **dataloader_kwargs}
         self.ds_train, self.ds_val, self.ds_test = None, None, None
 
     def prepare_data(self):
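The fix above derives `persistent_workers` from the requested worker count instead of hard-coding `True`, because PyTorch rejects `persistent_workers=True` together with `num_workers=0`. A minimal sketch of the merge logic (the helper name is hypothetical; the real code inlines this in `DataModule.__init__`):

```python
def make_dataloader_kwargs(**dataloader_kwargs):
    """Build DataLoader kwargs, enabling persistent workers only when valid.

    PyTorch raises a ValueError if persistent_workers=True is combined with
    num_workers=0, so the flag is derived from the requested worker count.
    Explicit caller overrides still win via the trailing **dataloader_kwargs.
    """
    return {'pin_memory': True,
            'persistent_workers': dataloader_kwargs.get('num_workers', 0) > 0,
            **dataloader_kwargs}
```

Because the unpacked `**dataloader_kwargs` comes last, a caller who explicitly passes `persistent_workers` still overrides the derived default.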
@@ -116,6 +116,7 @@ class AttentionBlockSE3(nn.Module):
                  use_layer_norm: bool = False,
                  max_degree: bool = 4,
                  fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
+                 low_memory: bool = False,
                  **kwargs
                  ):
         """
@@ -140,7 +141,7 @@ class AttentionBlockSE3(nn.Module):
 
         self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
                                     use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
-                                    allow_fused_output=True)
+                                    allow_fused_output=True, low_memory=low_memory)
         self.to_query = LinearSE3(fiber_in, key_query_fiber)
         self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
         self.project = LinearSE3(value_fiber + fiber_in, fiber_out)
@@ -29,6 +29,7 @@ import dgl
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.utils.checkpoint
 from dgl import DGLGraph
 from torch import Tensor
 from torch.cuda.nvtx import range as nvtx_range
@@ -185,7 +186,8 @@ class ConvSE3(nn.Module):
                  self_interaction: bool = False,
                  max_degree: int = 4,
                  fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
-                 allow_fused_output: bool = False
+                 allow_fused_output: bool = False,
+                 low_memory: bool = False
                  ):
         """
         :param fiber_in: Fiber describing the input features
@@ -205,6 +207,7 @@ class ConvSE3(nn.Module):
         self.self_interaction = self_interaction
         self.max_degree = max_degree
         self.allow_fused_output = allow_fused_output
+        self.conv_checkpoint = torch.utils.checkpoint.checkpoint if low_memory else lambda m, *x: m(*x)
 
         # channels_in: account for the concatenation of edge features
         channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
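The `conv_checkpoint` attribute picks, at construction time, between gradient checkpointing and a plain call, so every downstream call site uses the same `fn(module, *inputs)` convention. A stand-alone sketch of the dispatch (the `checkpoint_fn` parameter is a stand-in for `torch.utils.checkpoint.checkpoint`, which recomputes activations during the backward pass to save memory):

```python
def make_conv_checkpoint(low_memory, checkpoint_fn):
    # low_memory=True routes every convolution call through checkpoint_fn
    # (torch.utils.checkpoint.checkpoint in the real code); otherwise a
    # pass-through lambda invokes the module directly with the same signature.
    return checkpoint_fn if low_memory else (lambda m, *x: m(*x))
```

Choosing the callable once in `__init__` keeps the forward pass free of `if low_memory:` branches.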
@@ -244,8 +247,8 @@ class ConvSE3(nn.Module):
             self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
             self.conv_in = nn.ModuleDict()
             for d_in, c_in in fiber_in:
-                sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
                 channels_in_new = c_in + fiber_edge[d_in] * (d_in > 0)
+                sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
                 self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, channels_in_new, list(channels_out_set)[0],
                                                            fuse_level=self.used_fuse_level, **common_args)
         else:
@@ -298,7 +301,9 @@ class ConvSE3(nn.Module):
 
         if self.used_fuse_level == ConvSE3FuseLevel.FULL:
             in_features_fused = torch.cat(in_features, dim=-1)
-            out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])
+            out = self.conv_checkpoint(
+                self.conv, in_features_fused, invariant_edge_feats, basis['fully_fused']
+            )
 
             if not self.allow_fused_output or self.self_interaction or self.pool:
                 out = unfuse_features(out, self.fiber_out.degrees)
@@ -308,13 +313,16 @@ class ConvSE3(nn.Module):
             for degree_out in self.fiber_out.degrees:
                 basis_used = basis[f'out{degree_out}_fused']
                 out[str(degree_out)] = self._try_unpad(
-                    self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats, basis_used),
-                    basis_used)
+                    self.conv_checkpoint(
+                        self.conv_out[str(degree_out)], in_features_fused, invariant_edge_feats, basis_used
+                    ), basis_used)
 
         elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
             out = 0
             for degree_in, feature in zip(self.fiber_in.degrees, in_features):
-                out = out + self.conv_in[str(degree_in)](feature, invariant_edge_feats, basis[f'in{degree_in}_fused'])
+                out = out + self.conv_checkpoint(
+                    self.conv_in[str(degree_in)], feature, invariant_edge_feats, basis[f'in{degree_in}_fused']
+                )
             if not self.allow_fused_output or self.self_interaction or self.pool:
                 out = unfuse_features(out, self.fiber_out.degrees)
         else:
@@ -325,8 +333,9 @@ class ConvSE3(nn.Module):
                 dict_key = f'{degree_in},{degree_out}'
                 basis_used = basis.get(dict_key, None)
                 out_feature = out_feature + self._try_unpad(
-                    self.conv[dict_key](feature, invariant_edge_feats, basis_used),
-                    basis_used)
+                    self.conv_checkpoint(
+                        self.conv[dict_key], feature, invariant_edge_feats, basis_used
+                    ), basis_used)
             out[str(degree_out)] = out_feature
 
         for degree_out in self.fiber_out.degrees:
@@ -101,11 +101,11 @@ class SE3Transformer(nn.Module):
         self.tensor_cores = tensor_cores
         self.low_memory = low_memory
 
-        if low_memory and not tensor_cores:
-            logging.warning('Low memory mode will have no effect with no Tensor Cores')
-
-        # Fully fused convolutions when using Tensor Cores (and not low memory mode)
-        fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL
+        if low_memory:
+            self.fuse_level = ConvSE3FuseLevel.NONE
+        else:
+            # Fully fused convolutions when using Tensor Cores (and not low memory mode)
+            self.fuse_level = ConvSE3FuseLevel.FULL if tensor_cores else ConvSE3FuseLevel.PARTIAL
 
         graph_modules = []
         for i in range(num_layers):
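The new branching makes the fuse level an explicit attribute: low-memory mode now forces unfused convolutions (so each one can be checkpointed individually), while Tensor Cores still select full fusion otherwise. A stand-alone sketch of the selection, using a stand-in enum that mirrors `ConvSE3FuseLevel` (the ordering NONE < PARTIAL < FULL is an assumption):

```python
from enum import IntEnum

class FuseLevel(IntEnum):
    # Stand-in for ConvSE3FuseLevel (assumed ordering: NONE < PARTIAL < FULL)
    NONE = 0
    PARTIAL = 1
    FULL = 2

def select_fuse_level(tensor_cores: bool, low_memory: bool) -> FuseLevel:
    if low_memory:
        # Unfused convolutions are slower but smaller, and each can be
        # wrapped in gradient checkpointing on its own.
        return FuseLevel.NONE
    # Fully fused convolutions when using Tensor Cores (and not low memory mode)
    return FuseLevel.FULL if tensor_cores else FuseLevel.PARTIAL
```

Note that, unlike the old code, `low_memory` now takes effect regardless of Tensor Core availability, which is why the old warning was dropped.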
@@ -116,7 +116,8 @@ class SE3Transformer(nn.Module):
                                                  channels_div=channels_div,
                                                  use_layer_norm=use_layer_norm,
                                                  max_degree=self.max_degree,
-                                                 fuse_level=fuse_level))
+                                                 fuse_level=self.fuse_level,
+                                                 low_memory=low_memory))
             if norm:
                 graph_modules.append(NormSE3(fiber_hidden))
             fiber_in = fiber_hidden
@@ -143,7 +144,7 @@ class SE3Transformer(nn.Module):
 
         # Add fused bases (per output degree, per input degree, and fully fused) to the dict
         basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
-                                        fully_fused=self.tensor_cores and not self.low_memory)
+                                        fully_fused=self.fuse_level == ConvSE3FuseLevel.FULL)
 
         edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)
@@ -62,6 +62,8 @@ PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=2
                     help='Do an evaluation round every N epochs')
 PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
                     help='Minimize stdout output')
+PARSER.add_argument('--wandb', type=str2bool, nargs='?', const=True, default=False,
+                    help='Enable W&B logging')
 
 PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
                     help='Benchmark mode')
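The new `--wandb` flag follows the repo's existing `str2bool` pattern: `nargs='?'` with `const=True` means a bare `--wandb` enables logging, while `--wandb false` disables it explicitly and omitting the flag keeps the default `False`. A sketch under the assumption that `str2bool` behaves like the common argparse helper (its exact implementation is not shown in this diff):

```python
import argparse

def str2bool(v):
    # Assumed behavior of the repo's helper: map common truthy/falsy strings.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected')

parser = argparse.ArgumentParser()
parser.add_argument('--wandb', type=str2bool, nargs='?', const=True, default=False,
                    help='Enable W&B logging')
```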
@@ -32,7 +32,7 @@ from tqdm import tqdm
 from se3_transformer.runtime import gpu_affinity
 from se3_transformer.runtime.arguments import PARSER
 from se3_transformer.runtime.callbacks import BaseCallback
-from se3_transformer.runtime.loggers import DLLogger
+from se3_transformer.runtime.loggers import DLLogger, WandbLogger, LoggerCollection
 from se3_transformer.runtime.utils import to_cuda, get_local_rank
@@ -87,7 +87,10 @@ if __name__ == '__main__':
 
     major_cc, minor_cc = torch.cuda.get_device_capability()
 
-    logger = DLLogger(args.log_dir, filename=args.dllogger_name)
+    loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)]
+    if args.wandb:
+        loggers.append(WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer'))
+    logger = LoggerCollection(loggers)
     datamodule = QM9DataModule(**vars(args))
     model = SE3TransformerPooled(
         fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
@@ -108,6 +111,7 @@ if __name__ == '__main__':
         nproc_per_node = torch.cuda.device_count()
         affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node)
         model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
+        model._set_static_graph()
 
     test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
     evaluate(model,
@@ -125,6 +125,7 @@ def train(model: nn.Module,
 
     if dist.is_initialized():
         model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
+        model._set_static_graph()
 
     model.train()
     grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
@@ -147,7 +148,8 @@ def train(model: nn.Module,
         if isinstance(train_dataloader.sampler, DistributedSampler):
             train_dataloader.sampler.set_epoch(epoch_idx)
 
-        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
+        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks,
+                           args)
         if dist.is_initialized():
             loss = torch.tensor(loss, dtype=torch.float, device=device)
             torch.distributed.all_reduce(loss)
@@ -163,7 +165,8 @@ def train(model: nn.Module,
                 and (epoch_idx + 1) % args.ckpt_interval == 0:
             save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)
 
-        if not args.benchmark and ((args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0) or epoch_idx + 1 == args.epochs):
+        if not args.benchmark and (
+                (args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0) or epoch_idx + 1 == args.epochs):
             evaluate(model, val_dataloader, callbacks, args)
             model.train()
@@ -197,10 +200,10 @@ if __name__ == '__main__':
     logging.info(f'Using seed {args.seed}')
     seed_everything(args.seed)
 
-    logger = LoggerCollection([
-        DLLogger(save_dir=args.log_dir, filename=args.dllogger_name),
-        WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
-    ])
+    loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)]
+    if args.wandb:
+        loggers.append(WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer'))
+    logger = LoggerCollection(loggers)
 
     datamodule = QM9DataModule(**vars(args))
     model = SE3TransformerPooled(
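Both entry points now build the logger list the same way: `DLLogger` always, `WandbLogger` only when `--wandb` is set, and everything wrapped in a `LoggerCollection`. A minimal sketch of the fan-out pattern (the classes below are simplified stand-ins for illustration, not the repo's implementations):

```python
class LoggerCollection:
    # Stand-in: forwards each log call to every wrapped logger.
    def __init__(self, loggers):
        self.loggers = loggers

    def log_metrics(self, metrics, step=None):
        for logger in self.loggers:
            logger.log_metrics(metrics, step)

class ListLogger:
    # Hypothetical test double that records what it receives
    # (DLLogger / WandbLogger play this role in the real code).
    def __init__(self):
        self.records = []

    def log_metrics(self, metrics, step=None):
        self.records.append((metrics, step))

def build_logger(wandb_enabled, make_wandb_logger):
    loggers = [ListLogger()]          # DLLogger in the real code
    if wandb_enabled:
        loggers.append(make_wandb_logger())
    return LoggerCollection(loggers)
```

Wrapping even a single logger in the collection keeps the call sites uniform, so no code downstream needs to know whether W&B is enabled.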
@@ -236,5 +239,3 @@ if __name__ == '__main__':
           args)
 
     logging.info('Training finished successfully')
-
-
@@ -2,9 +2,9 @@ from setuptools import setup, find_packages
 
 setup(
     name='se3-transformer',
-    packages=find_packages(),
+    packages=find_packages(exclude=['tests']),
     include_package_data=True,
-    version='1.0.0',
+    version='1.1.0',
     description='PyTorch + DGL implementation of SE(3)-Transformers',
     author='Alexandre Milesi',
     author_email='alexandrem@nvidia.com',
@@ -25,7 +25,11 @@ import torch
 
 from se3_transformer.model import SE3Transformer
 from se3_transformer.model.fiber import Fiber
-from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot
+
+if __package__ is None or __package__ == '':
+    from utils import get_random_graph, assign_relative_pos, get_max_diff, rot
+else:
+    from .utils import get_random_graph, assign_relative_pos, get_max_diff, rot
 
 # Tolerances for equivariance error abs( f(x) @ R - f(x @ R) )
 TOL = 1e-3
@@ -83,7 +83,7 @@ These examples, along with our NVIDIA deep learning software stack, are provided
 ## Graph Neural Networks
 | Models | Framework | A100 | AMP | Multi-GPU | Multi-Node | TRT | ONNX | Triton | DLC | NB |
 | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |------------- |------------- |------------- |------------- |------------- |
-| [SE(3)-Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/DrugDiscovery/SE3Transformer) | PyTorch | Yes | Yes | Yes | - | - | - | - | - | - |
+| [SE(3)-Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/DGLPyTorch/DrugDiscovery/SE3Transformer) | PyTorch | Yes | Yes | Yes | - | - | - | - | - | - |
 
 ## NVIDIA support