[SE3Transformer/DGLPyT] Better low memory mode

Alexandre Milesi 2021-11-01 17:49:17 +01:00
parent dcd3bbac09
commit 22d6621dcd
19 changed files with 1610 additions and 1574 deletions

View file

@@ -42,7 +42,6 @@ RUN make -j8
FROM ${FROM_IMAGE_NAME}
RUN rm -rf /workspace/*
WORKDIR /workspace/se3-transformer
# copy built DGL and install it
@@ -55,3 +54,5 @@ ADD . .
ENV DGLBACKEND=pytorch
ENV OMP_NUM_THREADS=1

View file

@@ -126,7 +126,13 @@ The following performance optimizations were implemented in this model:
- The layout (order of dimensions) of the bases tensors is optimized to avoid copies to contiguous memory in the downstream TFN layers
- When Tensor Cores are available, and the output feature dimension of computed bases is odd, then it is padded with zeros to make more effective use of Tensor Cores (AMP and TF32 precisions)
- Multiple levels of fusion for TFN convolutions (and radial profiles) are provided and automatically used when conditions are met
- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`)
- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`). Overview of memory savings over the official implementation (batch size 100), depending on the precision and the low memory mode:
| Mode | FP32 | AMP |
|---|---|---|
| `--low_memory false` (default) | 4.7x | 7.1x |
| `--low_memory true` | 29.4x | 43.6x |
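The extra savings in low-memory mode come from recomputing activations during the backward pass instead of storing them (gradient checkpointing). A minimal sketch of how such peak-memory ratios can be measured, where `build_model(low_memory=...)` is a hypothetical stand-in for constructing the model with and without the flag (the real switch is `--low_memory` in `training.py`):

```python
import torch

def peak_memory_mib(model: torch.nn.Module, batch, device: str = 'cuda') -> float:
    """Run one forward/backward pass and report peak allocated CUDA memory in MiB."""
    model = model.to(device)
    torch.cuda.reset_peak_memory_stats(device)
    loss = model(*batch).sum()
    loss.backward()
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device) / 2 ** 20

# Hypothetical usage; build_model(low_memory=...) stands in for building the
# SE(3)-Transformer in the two modes:
# baseline = peak_memory_mib(build_model(low_memory=False), batch)
# low_mem  = peak_memory_mib(build_model(low_memory=True), batch)
# print(f'memory saving: {baseline / low_mem:.1f}x')
```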
**Self-attention optimizations**
@@ -358,7 +364,7 @@ The complete list of the available parameters for the `training.py` script conta
- `--pooling`: Type of graph pooling (default: `max`)
- `--norm`: Apply a normalization layer after each attention block (default: `false`)
- `--use_layer_norm`: Apply layer normalization between MLP layers (default: `false`)
- `--low_memory`: If true, will use fused ops that are slower but use less memory (expect 25 percent less memory). Only has an effect if AMP is enabled on NVIDIA Volta GPUs or if running on Ampere GPUs (default: `false`)
- `--low_memory`: If true, will use ops that are slower but use less memory (default: `false`)
- `--num_degrees`: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: `4`)
- `--num_channels`: Number of channels for the hidden features (default: `32`)
@@ -407,7 +413,8 @@ The training script is `se3_transformer/runtime/training.py`, to be run as a mod
By default, the resulting logs are stored in `/results/`. This can be changed with `--log_dir`.
You can connect your existing Weights & Biases account by setting the `WANDB_API_KEY` environment variable.
You can connect your existing Weights & Biases account by setting the `WANDB_API_KEY` environment variable and enabling the `--wandb` flag.
If no API key is set, `--wandb` will log the run anonymously to Weights & Biases.
**Checkpoints**
@@ -573,6 +580,11 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
### Changelog
November 2021:
- Improved low memory mode to give further 6x memory savings
- Disabled W&B logging by default
- Fixed persistent workers when using one data loading process
October 2021:
- Updated README performance tables
- Fixed shape mismatch when using partially fused TFNs per output degree

View file

@@ -46,7 +46,8 @@ class DataModule(ABC):
if dist.is_initialized():
dist.barrier(device_ids=[get_local_rank()])
self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs}
self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': dataloader_kwargs.get('num_workers', 0) > 0,
**dataloader_kwargs}
self.ds_train, self.ds_val, self.ds_test = None, None, None
def prepare_data(self):
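
The guard above matters because PyTorch's `DataLoader` rejects `persistent_workers=True` when `num_workers` is 0, which is exactly the single-process data-loading case mentioned in the changelog. A small sketch of the failure and of the fixed kwargs pattern (toy dataset, not from the repo):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10, dtype=torch.float32))

# persistent_workers is only valid together with num_workers > 0;
# with num_workers=0 the DataLoader constructor raises a ValueError.
try:
    DataLoader(ds, num_workers=0, persistent_workers=True)
except ValueError as err:
    print(err)

# The fixed kwargs enable persistent workers only when worker processes exist:
num_workers = 0
kwargs = {'pin_memory': True, 'persistent_workers': num_workers > 0, 'num_workers': num_workers}
loader = DataLoader(ds, **kwargs)
print(next(iter(loader)))
```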

View file

@@ -116,6 +116,7 @@ class AttentionBlockSE3(nn.Module):
use_layer_norm: bool = False,
max_degree: bool = 4,
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
low_memory: bool = False,
**kwargs
):
"""
@@ -140,7 +141,7 @@
self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
allow_fused_output=True)
allow_fused_output=True, low_memory=low_memory)
self.to_query = LinearSE3(fiber_in, key_query_fiber)
self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
self.project = LinearSE3(value_fiber + fiber_in, fiber_out)

View file

@@ -29,6 +29,7 @@ import dgl
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
from dgl import DGLGraph
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
@@ -185,7 +186,8 @@ class ConvSE3(nn.Module):
self_interaction: bool = False,
max_degree: int = 4,
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
allow_fused_output: bool = False
allow_fused_output: bool = False,
low_memory: bool = False
):
"""
:param fiber_in: Fiber describing the input features
@@ -205,6 +207,7 @@
self.self_interaction = self_interaction
self.max_degree = max_degree
self.allow_fused_output = allow_fused_output
self.conv_checkpoint = torch.utils.checkpoint.checkpoint if low_memory else lambda m, *x: m(*x)
# channels_in: account for the concatenation of edge features
channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
@@ -244,8 +247,8 @@
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
self.conv_in = nn.ModuleDict()
for d_in, c_in in fiber_in:
sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
channels_in_new = c_in + fiber_edge[d_in] * (d_in > 0)
sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, channels_in_new, list(channels_out_set)[0],
fuse_level=self.used_fuse_level, **common_args)
else:
@@ -298,7 +301,9 @@
if self.used_fuse_level == ConvSE3FuseLevel.FULL:
in_features_fused = torch.cat(in_features, dim=-1)
out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])
out = self.conv_checkpoint(
self.conv, in_features_fused, invariant_edge_feats, basis['fully_fused']
)
if not self.allow_fused_output or self.self_interaction or self.pool:
out = unfuse_features(out, self.fiber_out.degrees)
@@ -308,13 +313,16 @@
for degree_out in self.fiber_out.degrees:
basis_used = basis[f'out{degree_out}_fused']
out[str(degree_out)] = self._try_unpad(
self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats, basis_used),
basis_used)
self.conv_checkpoint(
self.conv_out[str(degree_out)], in_features_fused, invariant_edge_feats, basis_used
), basis_used)
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
out = 0
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
out = out + self.conv_in[str(degree_in)](feature, invariant_edge_feats, basis[f'in{degree_in}_fused'])
out = out + self.conv_checkpoint(
self.conv_in[str(degree_in)], feature, invariant_edge_feats, basis[f'in{degree_in}_fused']
)
if not self.allow_fused_output or self.self_interaction or self.pool:
out = unfuse_features(out, self.fiber_out.degrees)
else:
@@ -325,8 +333,9 @@
dict_key = f'{degree_in},{degree_out}'
basis_used = basis.get(dict_key, None)
out_feature = out_feature + self._try_unpad(
self.conv[dict_key](feature, invariant_edge_feats, basis_used),
basis_used)
self.conv_checkpoint(
self.conv[dict_key], feature, invariant_edge_feats, basis_used
), basis_used)
out[str(degree_out)] = out_feature
for degree_out in self.fiber_out.degrees:
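
The `conv_checkpoint` attribute introduced above is the core of the improved low-memory mode: when `low_memory` is set, every convolution call is routed through `torch.utils.checkpoint.checkpoint`, so intermediate activations are freed after the forward pass and recomputed during backward. A standalone sketch of the same dispatch pattern (toy `mlp` module, not from the repo):

```python
import torch
import torch.utils.checkpoint

def make_conv_call(low_memory: bool):
    # Same idea as ConvSE3.conv_checkpoint: either checkpoint the module call
    # (recompute activations in backward) or call the module directly.
    if low_memory:
        return torch.utils.checkpoint.checkpoint
    return lambda module, *args: module(*args)

mlp = torch.nn.Sequential(torch.nn.Linear(32, 128), torch.nn.ReLU(), torch.nn.Linear(128, 32))
# The input must require grad, otherwise reentrant checkpointing has nothing to backprop through.
x = torch.randn(64, 32, requires_grad=True)

call = make_conv_call(low_memory=True)
out = call(mlp, x)          # numerically identical to mlp(x), but activations are not stored
out.sum().backward()
print(mlp[0].weight.grad.shape)
```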

View file

@@ -101,11 +101,11 @@ class SE3Transformer(nn.Module):
self.tensor_cores = tensor_cores
self.low_memory = low_memory
if low_memory and not tensor_cores:
logging.warning('Low memory mode will have no effect with no Tensor Cores')
# Fully fused convolutions when using Tensor Cores (and not low memory mode)
fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL
if low_memory:
self.fuse_level = ConvSE3FuseLevel.NONE
else:
# Fully fused convolutions when using Tensor Cores (and not low memory mode)
self.fuse_level = ConvSE3FuseLevel.FULL if tensor_cores else ConvSE3FuseLevel.PARTIAL
graph_modules = []
for i in range(num_layers):
@@ -116,7 +116,8 @@
channels_div=channels_div,
use_layer_norm=use_layer_norm,
max_degree=self.max_degree,
fuse_level=fuse_level))
fuse_level=self.fuse_level,
low_memory=low_memory))
if norm:
graph_modules.append(NormSE3(fiber_hidden))
fiber_in = fiber_hidden
@@ -143,7 +144,7 @@
# Add fused bases (per output degree, per input degree, and fully fused) to the dict
basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
fully_fused=self.tensor_cores and not self.low_memory)
fully_fused=self.fuse_level == ConvSE3FuseLevel.FULL)
edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)

View file

@@ -62,6 +62,8 @@ PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=2
help='Do an evaluation round every N epochs')
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
help='Minimize stdout output')
PARSER.add_argument('--wandb', type=str2bool, nargs='?', const=True, default=False,
help='Enable W&B logging')
PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
help='Benchmark mode')
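
The new `--wandb` flag follows the same `str2bool` convention as the surrounding arguments, so it can be passed bare (`--wandb`) or with an explicit value (`--wandb false`). A sketch of how this argparse pattern behaves; the `str2bool` below is a typical implementation, not necessarily the repository's own:

```python
import argparse

def str2bool(v) -> bool:
    # Typical helper for boolean CLI flags; the repo ships its own version.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {v!r}')

parser = argparse.ArgumentParser()
parser.add_argument('--wandb', type=str2bool, nargs='?', const=True, default=False,
                    help='Enable W&B logging')

print(parser.parse_args([]).wandb)                    # False (default)
print(parser.parse_args(['--wandb']).wandb)           # True  (bare flag uses const)
print(parser.parse_args(['--wandb', 'false']).wandb)  # False (explicit value)
```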

View file

@@ -32,7 +32,7 @@ from tqdm import tqdm
from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import BaseCallback
from se3_transformer.runtime.loggers import DLLogger
from se3_transformer.runtime.loggers import DLLogger, WandbLogger, LoggerCollection
from se3_transformer.runtime.utils import to_cuda, get_local_rank
@@ -87,7 +87,10 @@ if __name__ == '__main__':
major_cc, minor_cc = torch.cuda.get_device_capability()
logger = DLLogger(args.log_dir, filename=args.dllogger_name)
loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)]
if args.wandb:
loggers.append(WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer'))
logger = LoggerCollection(loggers)
datamodule = QM9DataModule(**vars(args))
model = SE3TransformerPooled(
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
@@ -108,6 +111,7 @@ if __name__ == '__main__':
nproc_per_node = torch.cuda.device_count()
affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node)
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
model._set_static_graph()
test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
evaluate(model,

View file

@@ -125,6 +125,7 @@ def train(model: nn.Module,
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
model._set_static_graph()
model.train()
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
@@ -147,7 +148,8 @@
if isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch_idx)
loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks,
args)
if dist.is_initialized():
loss = torch.tensor(loss, dtype=torch.float, device=device)
torch.distributed.all_reduce(loss)
@@ -163,7 +165,8 @@
and (epoch_idx + 1) % args.ckpt_interval == 0:
save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)
if not args.benchmark and ((args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0) or epoch_idx + 1 == args.epochs):
if not args.benchmark and (
(args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0) or epoch_idx + 1 == args.epochs):
evaluate(model, val_dataloader, callbacks, args)
model.train()
@@ -197,10 +200,10 @@ if __name__ == '__main__':
logging.info(f'Using seed {args.seed}')
seed_everything(args.seed)
logger = LoggerCollection([
DLLogger(save_dir=args.log_dir, filename=args.dllogger_name),
WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
])
loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)]
if args.wandb:
loggers.append(WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer'))
logger = LoggerCollection(loggers)
datamodule = QM9DataModule(**vars(args))
model = SE3TransformerPooled(
@@ -236,5 +239,3 @@ if __name__ == '__main__':
args)
logging.info('Training finished successfully')
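
The `model._set_static_graph()` calls added after the `DistributedDataParallel` wrappers are most likely there because the checkpointed convolutions re-enter autograd, a pattern DDP handles cleanly when the graph is declared static (`_set_static_graph()` is a private API; newer PyTorch exposes the same behavior through the `static_graph=True` constructor argument). A single-process sketch of the combination, under those assumptions and with a toy module instead of the SE(3)-Transformer:

```python
import os
import torch
import torch.distributed as dist
import torch.utils.checkpoint
from torch.nn.parallel import DistributedDataParallel

class CheckpointedBlock(torch.nn.Module):
    """Toy stand-in for a layer that checkpoints its forward pass."""
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))

    def forward(self, x):
        # Activations of self.net are recomputed during backward to save memory
        return torch.utils.checkpoint.checkpoint(self.net, x)

if __name__ == '__main__':
    # Single-process process group, just to construct DDP locally
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    model = DistributedDataParallel(CheckpointedBlock())
    model._set_static_graph()  # tell DDP the graph does not change across iterations

    x = torch.randn(4, 8, requires_grad=True)
    model(x).sum().backward()
    dist.destroy_process_group()
```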

View file

@@ -2,9 +2,9 @@ from setuptools import setup, find_packages
setup(
name='se3-transformer',
packages=find_packages(),
packages=find_packages(exclude=['tests']),
include_package_data=True,
version='1.0.0',
version='1.1.0',
description='PyTorch + DGL implementation of SE(3)-Transformers',
author='Alexandre Milesi',
author_email='alexandrem@nvidia.com',

View file

@@ -25,7 +25,11 @@ import torch
from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber
from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot
if __package__ is None or __package__ == '':
from utils import get_random_graph, assign_relative_pos, get_max_diff, rot
else:
from .utils import get_random_graph, assign_relative_pos, get_max_diff, rot
# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) )
TOL = 1e-3
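
The comment above states the property the tests check: rotating the input before the network and rotating the output after it must agree up to `TOL`. A self-contained sketch of the same check on a toy rotation-equivariant function (not the repository's test utilities):

```python
import math
import torch

def rot_z(theta: float) -> torch.Tensor:
    """Rotation matrix about the z-axis."""
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def f(x: torch.Tensor) -> torch.Tensor:
    # Toy equivariant map: scale each 3D vector by a function of its (rotation-invariant) norm
    return x * torch.tanh(x.norm(dim=-1, keepdim=True))

x = torch.randn(128, 3)
R = rot_z(0.7)
err = (f(x) @ R - f(x @ R)).abs().max().item()
assert err < 1e-5, err
print(f'max equivariance error: {err:.2e}')
```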

View file

@@ -83,7 +83,7 @@ These examples, along with our NVIDIA deep learning software stack, are provided
## Graph Neural Networks
| Models | Framework | A100 | AMP | Multi-GPU | Multi-Node | TRT | ONNX | Triton | DLC | NB |
| ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |------------- |------------- |------------- |------------- |------------- |
| [SE(3)-Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/DrugDiscovery/SE3Transformer) | PyTorch | Yes | Yes | Yes | - | - | - | - | - | - |
| [SE(3)-Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/DGLPyTorch/DrugDiscovery/SE3Transformer) | PyTorch | Yes | Yes | Yes | - | - | - | - | - | - |
## NVIDIA support