[SE3Transformer/PyT] Public release and CI

This commit is contained in:
Alexandre Milesi (Engrg-Hardware 1) 2021-08-20 11:57:31 +02:00 committed by Andrei Shumak
parent b6e5ebdbc9
commit 85c54b5c36
41 changed files with 3961 additions and 0 deletions

@@ -0,0 +1,123 @@
.Trash-0
.git
data/
.DS_Store
*wandb/
*.pt
*.swp
# added by FAFU
.idea/
cache/
downloaded/
*.lprof
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
**/benchmark
**/results
*.pkl
*.log

@@ -0,0 +1,121 @@
data/
.DS_Store
*wandb/
*.pt
*.swp
# added by FAFU
.idea/
cache/
downloaded/
*.lprof
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
**/benchmark
**/results
*.pkl
*.log

@@ -0,0 +1,58 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
# run docker daemon with --default-runtime=nvidia for GPU detection during build
# multistage build for DGL with CUDA and FP16
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.07-py3

FROM ${FROM_IMAGE_NAME} AS dgl_builder

ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update \
    && apt-get install -y git build-essential python3-dev make cmake \
    && rm -rf /var/lib/apt/lists/*
WORKDIR /dgl
RUN git clone --branch v0.7.0 --recurse-submodules --depth 1 https://github.com/dmlc/dgl.git .
RUN sed -i 's/"35 50 60 70"/"60 70 80"/g' cmake/modules/CUDA.cmake
WORKDIR build
RUN cmake -DUSE_CUDA=ON -DUSE_FP16=ON ..
RUN make -j8

FROM ${FROM_IMAGE_NAME}

RUN rm -rf /workspace/*
WORKDIR /workspace/se3-transformer

# copy built DGL and install it
COPY --from=dgl_builder /dgl ./dgl
RUN cd dgl/python && python setup.py install && cd ../.. && rm -rf dgl

ADD requirements.txt .
RUN pip install --no-cache-dir --upgrade --pre pip
RUN pip install --no-cache-dir -r requirements.txt
ADD . .

ENV DGLBACKEND=pytorch
ENV OMP_NUM_THREADS=1

@@ -0,0 +1,7 @@
Copyright 2021 NVIDIA CORPORATION & AFFILIATES
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@@ -0,0 +1,7 @@
SE(3)-Transformer PyTorch
This repository includes software from https://github.com/FabianFuchsML/se3-transformer-public
licensed under the MIT License.
This repository includes software from https://github.com/lucidrains/se3-transformer-pytorch
licensed under the MIT License.

@@ -0,0 +1,580 @@
# SE(3)-Transformers For PyTorch
This repository provides a script and recipe to train the SE(3)-Transformer model to achieve state-of-the-art accuracy. The content of this repository is tested and maintained by NVIDIA.
## Table Of Contents
- [Model overview](#model-overview)
    * [Model architecture](#model-architecture)
    * [Default configuration](#default-configuration)
    * [Feature support matrix](#feature-support-matrix)
        * [Features](#features)
    * [Mixed precision training](#mixed-precision-training)
        * [Enabling mixed precision](#enabling-mixed-precision)
        * [Enabling TF32](#enabling-tf32)
    * [Glossary](#glossary)
- [Setup](#setup)
    * [Requirements](#requirements)
- [Quick Start Guide](#quick-start-guide)
- [Advanced](#advanced)
    * [Scripts and sample code](#scripts-and-sample-code)
    * [Parameters](#parameters)
    * [Command-line options](#command-line-options)
    * [Getting the data](#getting-the-data)
        * [Dataset guidelines](#dataset-guidelines)
        * [Multi-dataset](#multi-dataset)
    * [Training process](#training-process)
    * [Inference process](#inference-process)
- [Performance](#performance)
    * [Benchmarking](#benchmarking)
        * [Training performance benchmark](#training-performance-benchmark)
        * [Inference performance benchmark](#inference-performance-benchmark)
    * [Results](#results)
        * [Training accuracy results](#training-accuracy-results)
            * [Training accuracy: NVIDIA DGX A100 (8x A100 80GB)](#training-accuracy-nvidia-dgx-a100-8x-a100-80gb)
            * [Training accuracy: NVIDIA DGX-1 (8x V100 16GB)](#training-accuracy-nvidia-dgx-1-8x-v100-16gb)
            * [Training stability test](#training-stability-test)
        * [Training performance results](#training-performance-results)
            * [Training performance: NVIDIA DGX A100 (8x A100 80GB)](#training-performance-nvidia-dgx-a100-8x-a100-80gb)
            * [Training performance: NVIDIA DGX-1 (8x V100 16GB)](#training-performance-nvidia-dgx-1-8x-v100-16gb)
        * [Inference performance results](#inference-performance-results)
            * [Inference performance: NVIDIA DGX A100 (1x A100 80GB)](#inference-performance-nvidia-dgx-a100-1x-a100-80gb)
            * [Inference performance: NVIDIA DGX-1 (1x V100 16GB)](#inference-performance-nvidia-dgx-1-1x-v100-16gb)
- [Release notes](#release-notes)
    * [Changelog](#changelog)
    * [Known issues](#known-issues)
## Model overview
The **SE(3)-Transformer** is a Graph Neural Network using a variant of [self-attention](https://arxiv.org/abs/1706.03762v5) for processing 3D points and graphs.
This model is [equivariant](https://en.wikipedia.org/wiki/Equivariant_map) under [continuous 3D roto-translations](https://en.wikipedia.org/wiki/Euclidean_group), meaning that when the inputs (graphs or sets of points) rotate in 3D space (or more generally experience a [proper rigid transformation](https://en.wikipedia.org/wiki/Rigid_transformation)), the model outputs either stay invariant or transform with the input.
A mathematical guarantee of equivariance is important to ensure stable and predictable performance in the presence of nuisance transformations of the data input and when the problem has some inherent symmetries we want to exploit.
The model is based on the following publications:
- [SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks](https://arxiv.org/abs/2006.10503) (NeurIPS 2020) by Fabian B. Fuchs, Daniel E. Worrall, et al.
- [Tensor field networks: Rotation- and translation-equivariant neural networks for 3D point clouds](https://arxiv.org/abs/1802.08219) by Nathaniel Thomas, Tess Smidt, et al.
A follow-up paper explains how this model can be used iteratively, for example, to predict or refine protein structures:
- [Iterative SE(3)-Transformers](https://arxiv.org/abs/2102.13419) by Fabian B. Fuchs, Daniel E. Worrall, et al.
Just like [the official implementation](https://github.com/FabianFuchsML/se3-transformer-public), this implementation uses [PyTorch](https://pytorch.org/) and the [Deep Graph Library (DGL)](https://www.dgl.ai/).
The main differences between this implementation of SE(3)-Transformers and the official one are the following:
- Training and inference support for multiple GPUs
- Training and inference support for [Mixed Precision](https://arxiv.org/abs/1710.03740)
- The [QM9 dataset from DGL](https://docs.dgl.ai/en/latest/api/python/dgl.data.html#qm9edge-dataset) is used and automatically downloaded
- Significantly increased throughput
- Significantly reduced memory consumption
- The use of layer normalization in the fully connected radial profile layers is an option (`--use_layer_norm`), off by default
- The use of equivariant normalization between attention layers is an option (`--norm`), off by default
- The [spherical harmonics](https://en.wikipedia.org/wiki/Spherical_harmonic) and [Clebsch-Gordan coefficients](https://en.wikipedia.org/wiki/Clebsch%E2%80%93Gordan_coefficients), used to compute bases matrices, are computed with the [e3nn library](https://e3nn.org/)
This model enables you to predict quantum chemical properties of small organic molecules in the [QM9 dataset](https://www.nature.com/articles/sdata201422).
In this case, the exploited symmetry is that these properties do not depend on the orientation or position of the molecules in space.
This model is trained with mixed precision using Tensor Cores on NVIDIA Volta, NVIDIA Turing, and the NVIDIA Ampere GPU architectures. Therefore, researchers can get results up to 1.5x faster than training without Tensor Cores while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
### Model architecture
The model consists of stacked layers of equivariant graph self-attention and equivariant normalization.
Lastly, a Tensor Field Network convolution is applied to obtain invariant features. Graph pooling (mean or max over the nodes) is applied to these features, and the result is fed to a final MLP to get scalar predictions.
In this setup, the model is a graph-to-scalar network. The pooling can be removed to obtain a graph-to-graph network, and the final TFN can be modified to output features of any type (invariant scalars, 3D vectors, ...).
![Model high-level architecture](./images/se3-transformer.png)
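To make the graph-to-scalar head concrete, here is a hedged sketch (class and feature names are illustrative, not the repository's API):
```
import dgl
import torch.nn as nn

class GraphToScalarHead(nn.Module):
    """ Illustrative pooling + MLP head: invariant node features -> one scalar per graph. """
    def __init__(self, num_channels: int, pooling: str = 'max'):
        super().__init__()
        self.pooling = pooling
        self.mlp = nn.Sequential(
            nn.Linear(num_channels, num_channels),
            nn.ReLU(),
            nn.Linear(num_channels, 1)
        )

    def forward(self, graph: dgl.DGLGraph, node_feats):
        graph.ndata['h'] = node_feats  # [num_nodes, num_channels] invariant (degree-0) features
        pooled = dgl.max_nodes(graph, 'h') if self.pooling == 'max' else dgl.mean_nodes(graph, 'h')
        return self.mlp(pooled).squeeze(-1)  # one scalar prediction per graph in the batch
```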
### Default configuration
SE(3)-Transformers introduce a self-attention layer for graphs that is equivariant to 3D roto-translations. It achieves this by leveraging Tensor Field Networks to build attention weights that are invariant and attention values that are equivariant.
Combining the equivariant values with the invariant weights gives rise to an equivariant output. This output is normalized while preserving equivariance thanks to equivariant normalization layers operating on feature norms.
The following features were implemented in this model:
- Support for edge features of any degree (1D, 3D, 5D, ...), whereas the official implementation only supports scalar invariant edge features (degree 0). Edge features with a degree greater than one are
concatenated to node features of the same degree. This is required in order to reproduce published results on point cloud processing.
- Data-parallel multi-GPU training (DDP)
- Mixed precision training (autocast, gradient scaling)
- Gradient accumulation
- Model checkpointing
The following performance optimizations were implemented in this model:
**General optimizations**
- The option is provided to precompute bases at the beginning of the training instead of computing them at the beginning of each forward pass (`--precompute_bases`)
- The bases computation is just-in-time (JIT) compiled with `torch.jit.script`
- The Clebsch-Gordan coefficients are cached in RAM
**Tensor Field Network optimizations**
- The last layer of each radial profile network does not add any bias in order to avoid large broadcasting operations
- The layout (order of dimensions) of the bases tensors is optimized to avoid copies to contiguous memory in the downstream TFN layers
- When Tensor Cores are available, and the output feature dimension of computed bases is odd, then it is padded with zeros to make more effective use of Tensor Cores (AMP and TF32 precisions)
- Multiple levels of fusion for TFN convolutions (and radial profiles) are provided and automatically used when conditions are met
- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`)
**Self-attention optimizations**
- Attention keys and values are computed by a single partial TFN graph convolution in each attention layer instead of two
- Graph operations for different output degrees may be fused together if conditions are met
**Normalization optimizations**
- The equivariant normalization layer is optimized from multiple layer normalizations to a group normalization on fused norms when certain conditions are met
Competitive training results and analysis are provided for the following hyperparameters (identical to the ones in the original publication):
- Number of layers: 7
- Number of degrees: 4
- Number of channels: 32
- Number of attention heads: 8
- Channels division: 2
- Use of equivariant normalization: true
- Use of layer normalization: true
- Pooling: max
### Feature support matrix
This model supports the following features:
| Feature                         | SE(3)-Transformer |
|---------------------------------|-------------------|
| Automatic mixed precision (AMP) | Yes               |
| Distributed data parallel (DDP) | Yes               |
#### Features
**Distributed data parallel (DDP)**
[DistributedDataParallel (DDP)](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implements data parallelism at the module level that can run across multiple GPUs or machines.
**Automatic Mixed Precision (AMP)**
This implementation uses the native PyTorch AMP implementation of mixed precision training. It allows us to use FP16 training with FP32 master weights by modifying just a few lines of code. A detailed explanation of mixed precision can be found in the next section.
### Mixed precision training
Mixed precision is the combined use of different numerical precisions in a computational method. [Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant computational speedup by performing operations in half-precision format while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of [Tensor Cores](https://developer.nvidia.com/tensor-cores) in NVIDIA Volta, and following with both the NVIDIA Turing and NVIDIA Ampere Architectures, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures. Using [mixed precision training](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html) previously required two steps:
1. Porting the model to use the FP16 data type where appropriate.
2. Adding loss scaling to preserve small gradient values.
AMP enables mixed precision training on NVIDIA Volta, NVIDIA Turing, and NVIDIA Ampere GPU architectures automatically. The PyTorch framework code makes all necessary model changes internally.
For information about:
- How to train using mixed precision, refer to the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html) documentation.
- Techniques used for mixed precision training, refer to the [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) blog.
- APEX tools for mixed precision training, refer to the [NVIDIA Apex: Tools for Easy Mixed-Precision Training in PyTorch](https://devblogs.nvidia.com/apex-pytorch-easy-mixed-precision-training/).
#### Enabling mixed precision
Mixed precision is enabled in PyTorch by using the native [Automatic Mixed Precision package](https://pytorch.org/docs/stable/amp.html), which casts variables to half-precision upon retrieval while storing variables in single-precision format. Furthermore, to preserve small gradient magnitudes in backpropagation, a [loss scaling](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html#lossscaling) step must be included when applying gradients. In PyTorch, loss scaling can be applied automatically using a `GradScaler`.
Automatic Mixed Precision makes all the adjustments internally in PyTorch, providing two benefits over manual operations. First, programmers need not modify network model code, reducing development and maintenance effort. Second, using AMP maintains forward and backward compatibility with all the APIs for defining and running PyTorch models.
To enable mixed precision, you can simply use the `--amp` flag when running the training or inference scripts.
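Under the hood, the `--amp` flag roughly corresponds to the standard native-AMP training pattern; here is a minimal sketch, where `model`, `optimizer`, `loss_fn`, and `dataloader` are placeholders:
```
import torch

scaler = torch.cuda.amp.GradScaler()

for inputs, targets in dataloader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        preds = model(inputs)           # forward pass runs in mixed precision
        loss = loss_fn(preds, targets)
    scaler.scale(loss).backward()       # scale the loss to preserve small gradients
    scaler.step(optimizer)              # unscale gradients, then step
    scaler.update()                     # adjust the scale factor
```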
#### Enabling TF32
TensorFloat-32 (TF32) is the new math mode in [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for handling the matrix math, also called tensor operations. TF32 running on Tensor Cores in A100 GPUs can provide up to 10x speedups compared to single-precision floating-point math (FP32) on NVIDIA Volta GPUs.
TF32 Tensor Cores can speed up networks using FP32, typically with no loss of accuracy. It is more robust than FP16 for models that require a high dynamic range for weights or activations.
For more information, refer to the [TensorFloat-32 in the A100 GPU Accelerates AI Training, HPC up to 20x](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) blog post.
TF32 is supported in the NVIDIA Ampere GPU architecture and is enabled by default.
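For reference, TF32 usage can also be controlled explicitly from PyTorch code; a minimal sketch (both flags already default to `True` in this container generation):
```
import torch

torch.backends.cuda.matmul.allow_tf32 = True  # TF32 for matrix multiplications on Tensor Cores
torch.backends.cudnn.allow_tf32 = True        # TF32 for cuDNN convolutions
```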
### Glossary
**Degree (type)**
In the model, every feature (input, output and hidden) transforms in an equivariant way in relation to the input graph. When we define a feature, we need to choose, in addition to the number of channels, which transformation rule it obeys.
The degree or type of a feature is a non-negative integer that describes how this feature transforms when the input rotates in 3D.
This is related to [irreducible representations](https://en.wikipedia.org/wiki/Irreducible_representation) of different rotation orders.
The degree of a feature determines its dimensionality. A type-d feature has a dimensionality of 2d+1.
Some common examples include:
- Degree 0: 1D scalars invariant to rotation
- Degree 1: 3D vectors that rotate according to 3D rotation matrices
- Degree 2: 5D vectors that rotate according to 5D [Wigner-D matrices](https://en.wikipedia.org/wiki/Wigner_D-matrix). These can represent symmetric traceless 3x3 matrices.
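This rule is what the `degree_to_dim` helper from `se3_transformer.runtime.utils` (imported throughout the code in this commit) encodes; as a sketch, it is a one-liner along these lines:
```
def degree_to_dim(degree: int) -> int:
    # a type-d feature lives in a (2d + 1)-dimensional space
    return 2 * degree + 1

assert [degree_to_dim(d) for d in range(3)] == [1, 3, 5]  # scalars, 3D vectors, degree-2 features
```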
**Fiber**
A fiber can be viewed as a representation of a set of features of different types or degrees (non-negative integers), where each feature type transforms according to its rule.
In this repository, a fiber can be seen as a dictionary with degrees as keys and numbers of channels as values.
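For illustration, here is a hedged sketch using the `Fiber` class (defined in `se3_transformer/model/fiber.py`, shown later in this commit):
```
from se3_transformer.model import Fiber

# 32 channels of invariant scalars (degree 0) and 32 channels of equivariant 3D vectors (degree 1)
fiber = Fiber({0: 32, 1: 32})
```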
**Multiplicity**
The multiplicity of a feature of a given type is the number of channels of this feature.
**Tensor Field Network**
A [Tensor Field Network](https://arxiv.org/abs/1802.08219) is a kind of equivariant graph convolution that can combine features of different degrees and produce new ones while preserving equivariance thanks to [tensor products](https://en.wikipedia.org/wiki/Tensor_product).
**Equivariance**
[Equivariance](https://en.wikipedia.org/wiki/Equivariant_map) is a property of a function or model stating that applying a symmetry transformation to the input and then computing the function produces the same result as computing the function and then applying the transformation to the output.
In the case of SE(3)-Transformer, the symmetry group is the group of continuous roto-translations (SE(3)).
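The rotation part of this property can be checked numerically; a hedged sketch with a toy stand-in model (uniform scaling commutes with rotations; in practice, `model` would be an SE(3)-Transformer):
```
import torch

def model(points):
    # Toy equivariant map from 3D points to 3D vectors
    return 2.0 * points

# 90-degree rotation around the z-axis
R = torch.tensor([[0., -1., 0.],
                  [1.,  0., 0.],
                  [0.,  0., 1.]])

x = torch.randn(10, 3)
# Equivariance: rotating the input and applying the model gives the same
# result as applying the model and then rotating the output
assert torch.allclose(model(x @ R.T), model(x) @ R.T, atol=1e-6)
```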
## Setup
The following section lists the requirements that you need to meet in order to start training the SE(3)-Transformer model.
### Requirements
This repository contains a Dockerfile which extends the PyTorch 21.07 NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
- [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
- PyTorch 21.07+ NGC container
- Supported GPUs:
- [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
- [NVIDIA Turing architecture](https://www.nvidia.com/en-us/design-visualization/technologies/turing-architecture/)
- [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/)
For more information about how to get started with NGC containers, refer to the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:
- [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html)
- [Accessing And Pulling From The NGC Container Registry](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#accessing_registry)
- [Running PyTorch](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/running.html#running)
If you cannot use the PyTorch NGC container, to set up the required environment or to create your own container, refer to the versioned [NVIDIA Container Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html).
## Quick Start Guide
To train your model using mixed or TF32 precision with Tensor Cores or FP32, perform the following steps using the default parameters of the SE(3)-Transformer model on the QM9 dataset. For the specifics concerning training and inference, refer to the [Advanced](#advanced) section.
1. Clone the repository.
```
git clone https://github.com/NVIDIA/DeepLearningExamples
cd DeepLearningExamples/PyTorch/DrugDiscovery/SE3Transformer
```
2. Build the `se3-transformer` PyTorch NGC container.
```
docker build -t se3-transformer .
```
3. Start an interactive session in the NGC container to run training/inference.
```
mkdir -p results
docker run -it --runtime=nvidia --shm-size=8g --ulimit memlock=-1 --ulimit stack=67108864 --rm -v ${PWD}/results:/results se3-transformer:latest
```
4. Start training.
```
bash scripts/train.sh
```
5. Start inference/predictions.
```
bash scripts/predict.sh
```
Now that you have your model trained and evaluated, you can compare your training results with our [Training accuracy results](#training-accuracy-results). You can also benchmark your performance against the [Training performance benchmark](#training-performance-results) or the [Inference performance benchmark](#inference-performance-results). Following the steps in these sections ensures that you achieve the same accuracy and performance results as stated in the [Results](#results) section.
## Advanced
The following sections provide greater details of the dataset, running training and inference, and the training results.
### Scripts and sample code
In the root directory, the most important files are:
- `Dockerfile`: container with the basic set of dependencies to run SE(3)-Transformers
- `requirements.txt`: set of extra requirements to run SE(3)-Transformers
- `se3_transformer/data_loading/qm9.py`: QM9 data loading and preprocessing, as well as bases precomputation
- `se3_transformer/model/layers/`: directory containing model architecture layers
- `se3_transformer/model/transformer.py`: main Transformer module
- `se3_transformer/model/basis.py`: logic for computing bases matrices
- `se3_transformer/runtime/training.py`: training script, to be run as a python module
- `se3_transformer/runtime/inference.py`: inference script, to be run as a python module
- `se3_transformer/runtime/metrics.py`: MAE metric with support for multi-GPU synchronization
- `se3_transformer/runtime/loggers.py`: [DLLogger](https://github.com/NVIDIA/dllogger) and [W&B](https://wandb.ai/) loggers
### Parameters
The complete list of available parameters for the `training.py` script is the following:
**General**
- `--epochs`: Number of training epochs (default: `100` for single-GPU)
- `--batch_size`: Batch size (default: `240`)
- `--seed`: Set a seed globally (default: `None`)
- `--num_workers`: Number of dataloading workers (default: `8`)
- `--amp`: Use Automatic Mixed Precision (default: `false`)
- `--gradient_clip`: Clipping of the gradient norms (default: `None`)
- `--accumulate_grad_batches`: Gradient accumulation (default: `1`)
- `--ckpt_interval`: Save a checkpoint every N epochs (default: `-1`)
- `--eval_interval`: Do an evaluation round every N epochs (default: `1`)
- `--silent`: Minimize stdout output (default: `false`)
**Paths**
- `--data_dir`: Directory where the data is located or should be downloaded (default: `./data`)
- `--log_dir`: Directory where the results logs should be saved (default: `/results`)
- `--save_ckpt_path`: File where the checkpoint should be saved (default: `None`)
- `--load_ckpt_path`: File of the checkpoint to be loaded (default: `None`)
**Optimizer**
- `--optimizer`: Optimizer to use (default: `adam`)
- `--learning_rate`: Learning rate to use (default: `0.002` for single-GPU)
- `--momentum`: Momentum to use (default: `0.9`)
- `--weight_decay`: Weight decay to use (default: `0.1`)
**QM9 dataset**
- `--task`: Regression task to train on (default: `homo`)
- `--precompute_bases`: Precompute bases at the beginning of the script during dataset initialization, instead of computing them at the beginning of each forward pass (default: `false`)
**Model architecture**
- `--num_layers`: Number of stacked Transformer layers (default: `7`)
- `--num_heads`: Number of heads in self-attention (default: `8`)
- `--channels_div`: Channels division before feeding to attention layer (default: `2`)
- `--pooling`: Type of graph pooling (default: `max`)
- `--norm`: Apply a normalization layer after each attention block (default: `false`)
- `--use_layer_norm`: Apply layer normalization between MLP layers (default: `false`)
- `--low_memory`: If true, will use fused ops that are slower but use less memory (expect 25 percent less memory). Only has an effect if AMP is enabled on NVIDIA Volta GPUs or if running on Ampere GPUs (default: `false`)
- `--num_degrees`: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: `4`)
- `--num_channels`: Number of channels for the hidden features (default: `32`)
### Command-line options
To show the full list of available options and their descriptions, use the `-h` or `--help` command-line option, for example: `python -m se3_transformer.runtime.training --help`.
### Dataset guidelines
#### Demo dataset
The SE(3)-Transformer was trained on the QM9 dataset.
The QM9 dataset is hosted on DGL servers and downloaded (38MB) automatically when needed. By default, it is stored in the `./data` directory, but this location can be changed with the `--data_dir` argument.
The dataset is saved as a `qm9_edge.npz` file and converted to DGL graphs at runtime.
As input features, we use:
- Node features (6D):
    - One-hot-encoded atom type (5D) (atom types: H, C, N, O, F)
    - Number of protons of each atom (1D)
- Edge features: one-hot-encoded bond type (4D) (bond types: single, double, triple, aromatic)
- The relative positions between adjacent nodes (atoms)
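For reference, the underlying DGL dataset can be exercised directly; a minimal sketch (the arguments mirror those used by `QM9DataModule` later in this commit):
```
from dgl.data import QM9EdgeDataset

ds = QM9EdgeDataset(label_keys=['homo'], raw_dir='./data')  # downloads qm9_edge.npz on first use
graph, label = ds[0]  # a DGLGraph and its regression target
print(graph.ndata['attr'].shape, graph.edata['edge_attr'].shape)  # node and edge features
```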
#### Custom datasets
To use this network on a new dataset, you can extend the `DataModule` class present in `se3_transformer/data_loading/data_module.py`.
Your custom collate function should return a tuple with:
- A (batched) DGLGraph object
- A dictionary of node features ({degree: tensor})
- A dictionary of edge features ({degree: tensor})
- (Optional) Precomputed bases as a dictionary
- Labels as a tensor
You can then modify the `training.py` and `inference.py` scripts to use your new data module.
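Matching the contract above, a minimal collate function without precomputed bases could look like the following sketch (feature names such as `'pos'` and `'feat'` are illustrative, not part of the repository):
```
import dgl
import torch

def collate(samples):
    # samples: list of (DGLGraph, label) pairs produced by your Dataset;
    # labels are assumed to be tensors of shape [1]
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)

    # Relative positions between adjacent nodes, expected by the model
    # (assumes 3D positions are stored in batched_graph.ndata['pos'])
    src, dst = batched_graph.edges()
    pos = batched_graph.ndata['pos']
    batched_graph.edata['rel_pos'] = pos[dst] - pos[src]

    node_feats = {'0': batched_graph.ndata['feat'][..., None]}  # degree-0 node features
    edge_feats = {'0': batched_graph.edata['feat'][..., None]}  # degree-0 edge features
    return batched_graph, node_feats, edge_feats, torch.cat(labels)
```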
### Training process
The training script is `se3_transformer/runtime/training.py`, to be run as a module: `python -m se3_transformer.runtime.training`.
**Logs**
By default, the resulting logs are stored in `/results/`. This can be changed with `--log_dir`.
You can connect your existing Weights & Biases account by setting the `WANDB_API_KEY` environment variable.
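For example (the key value is a placeholder):
```
export WANDB_API_KEY=your_api_key_here
```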
**Checkpoints**
The argument `--save_ckpt_path` can be set to the path of the file where the checkpoints should be saved.
`--ckpt_interval` can also be set to the interval (in the number of epochs) between checkpoints.
**Evaluation**
The evaluation metric is the Mean Absolute Error (MAE).
`--eval_interval` can be set to the interval (in the number of epochs) between evaluation rounds. By default, an evaluation round is performed after each epoch.
**Automatic Mixed Precision**
To enable Mixed Precision training, add the `--amp` flag.
**Multi-GPU and multi-node**
The training script supports the PyTorch elastic launcher to run on multiple GPUs or nodes. Refer to the [official documentation](https://pytorch.org/docs/1.9.0/elastic/run.html).
For example, to train on all available GPUs with AMP:
```
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --module se3_transformer.runtime.training --amp
```
### Inference process
Inference can be run with the `se3_transformer.runtime.inference` Python module: `python -m se3_transformer.runtime.inference`. It requires a pre-trained model checkpoint (passed with `--load_ckpt_path`).
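For example, a single-GPU run mirroring `scripts/predict.sh` (shipped in this commit):
```
python -m se3_transformer.runtime.inference \
    --amp true \
    --batch_size 240 \
    --use_layer_norm \
    --norm \
    --load_ckpt_path model_qm9.pth \
    --task homo
```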
## Performance
The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA's latest software release. For the most up-to-date performance measurements, go to [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
### Benchmarking
The following section shows how to run benchmarks measuring the model performance in training and inference modes.
#### Training performance benchmark
To benchmark the training performance on a specific batch size, run `bash scripts/benchmark_train.sh {BATCH_SIZE}` for single GPU, and `bash scripts/benchmark_train_multi_gpu.sh {BATCH_SIZE}` for multi-GPU.
#### Inference performance benchmark
To benchmark the inference performance on a specific batch size, run `bash scripts/benchmark_inference.sh {BATCH_SIZE}`.
### Results
The following sections provide details on how we achieved our performance and accuracy in training and inference.
#### Training accuracy results
##### Training accuracy: NVIDIA DGX A100 (8x A100 80GB)
Our results were obtained by running the `scripts/train.sh` training script in the PyTorch 21.07 NGC container on NVIDIA DGX A100 (8x A100 80GB) GPUs.
| GPUs | Batch size / GPU | Absolute error - TF32 | Absolute error - mixed precision | Time to train - TF32 | Time to train - mixed precision | Time to train speedup (mixed precision to TF32) |
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
| 1 | 240 | 0.03456 | 0.03460 | 1h23min | 1h03min | 1.32x |
| 8 | 240 | 0.03417 | 0.03424 | 35min | 27min | 1.30x |
##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
Our results were obtained by running the `scripts/train.sh` training script in the PyTorch 21.07 NGC container on NVIDIA DGX-1 (8x V100 16GB) GPUs.
| GPUs | Batch size / GPU | Absolute error - FP32 | Absolute error - mixed precision | Time to train - FP32 | Time to train - mixed precision | Time to train speedup (mixed precision to FP32) |
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
| 1 | 240 | 0.03432 | 0.03439 | 2h25min | 1h33min | 1.56x |
| 8 | 240 | 0.03380 | 0.03495 | 1h08min | 44min | 1.55x |
#### Training performance results
##### Training performance: NVIDIA DGX A100 (8x A100 80GB)
Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 21.07 NGC container on NVIDIA DGX A100 with 8x A100 80GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.
| GPUs | Batch size / GPU | Throughput - TF32 [mol/ms] | Throughput - mixed precision [mol/ms] | Throughput speedup (mixed precision - TF32) | Weak scaling - TF32 | Weak scaling - mixed precision |
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
| 1 | 240 | 2.21 | 2.92 | 1.32x | | |
| 1 | 120 | 1.81 | 2.04 | 1.13x | | |
| 8 | 240 | 17.15 | 22.95 | 1.34x | 7.76 | 7.86 |
| 8 | 120 | 13.89 | 15.62 | 1.12x | 7.67 | 7.66 |
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
##### Training performance: NVIDIA DGX-1 (8x V100 16GB)
Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with 8x V100 16GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.
| GPUs | Batch size / GPU | Throughput - FP32 [mol/ms] | Throughput - mixed precision [mol/ms] | Throughput speedup (mixed precision - FP32) | Weak scaling - FP32 | Weak scaling - mixed precision |
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
| 1 | 240 | 1.25 | 1.88 | 1.50x | | |
| 1 | 120 | 1.03 | 1.41 | 1.37x | | |
| 8 | 240 | 9.33 | 14.02 | 1.50x | 7.46 | 7.46 |
| 8 | 120 | 7.39 | 9.41 | 1.27x | 7.17 | 6.67 |
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
#### Inference performance results
##### Inference performance: NVIDIA DGX A100 (1x A100 80GB)
Our results were obtained by running the `scripts/benchmark_inference.sh` inference benchmarking script in the PyTorch 21.07 NGC container on NVIDIA DGX A100 with 1x A100 80GB GPU.
FP16
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
| 1600 | 11.60 | 140.94 | 138.29 | 140.12 | 386.40 |
| 800 | 10.74 | 75.69 | 75.74 | 76.50 | 79.77 |
| 400 | 8.86 | 45.57 | 46.11 | 46.60 | 49.97 |
TF32
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
| 1600 | 8.58 | 189.20 | 186.39 | 187.71 | 420.28 |
| 800 | 8.28 | 97.56 | 97.20 | 97.73 | 101.13 |
| 400 | 7.55 | 53.38 | 53.72 | 54.48 | 56.62 |
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
##### Inference performance: NVIDIA DGX-1 (1x V100 16GB)
Our results were obtained by running the `scripts/benchmark_inference.sh` inference benchmarking script in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with 1x V100 16GB GPU.
FP16
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
| 1600 | 6.42 | 254.54 | 247.97 | 249.29 | 721.15 |
| 800 | 6.13 | 132.07 | 131.90 | 132.70 | 140.15 |
| 400 | 5.37 | 75.12 | 76.01 | 76.66 | 79.90 |
FP32
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
| 1600 | 3.39 | 475.86 | 473.82 | 475.64 | 891.18 |
| 800 | 3.36 | 239.17 | 240.64 | 241.65 | 243.70 |
| 400 | 3.17 | 126.67 | 128.19 | 128.82 | 130.54 |
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
## Release notes
### Changelog
August 2021
- Initial release
### Known issues
If you encounter `OSError: [Errno 12] Cannot allocate memory` during DataLoader iterator creation (more precisely, during the `fork()` call), this is most likely due to the use of the `--precompute_bases` flag. If you cannot add more RAM or swap to your machine, it is recommended to turn off bases precomputation by removing the `--precompute_bases` flag or by using `--precompute_bases false`.

Binary file not shown (image, 1.1 MiB).

@@ -0,0 +1,2 @@
e3nn==0.3.3
wandb==0.12.0

@@ -0,0 +1,15 @@
#!/usr/bin/env bash
# Script to benchmark inference performance, without bases precomputation
# CLI args with defaults
BATCH_SIZE=${1:-240}
AMP=${2:-true}
CUDA_VISIBLE_DEVICES=0 python -m se3_transformer.runtime.inference \
--amp "$AMP" \
--batch_size "$BATCH_SIZE" \
--use_layer_norm \
--norm \
--task homo \
--seed 42 \
--benchmark

@@ -0,0 +1,18 @@
#!/usr/bin/env bash
# Script to benchmark single-GPU training performance, with bases precomputation
# CLI args with defaults
BATCH_SIZE=${1:-240}
AMP=${2:-true}
CUDA_VISIBLE_DEVICES=0 python -m se3_transformer.runtime.training \
--amp "$AMP" \
--batch_size "$BATCH_SIZE" \
--epochs 6 \
--use_layer_norm \
--norm \
--save_ckpt_path model_qm9.pth \
--task homo \
--precompute_bases \
--seed 42 \
--benchmark

@@ -0,0 +1,19 @@
#!/usr/bin/env bash
# Script to benchmark multi-GPU training performance, with bases precomputation
# CLI args with defaults
BATCH_SIZE=${1:-240}
AMP=${2:-true}
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \
se3_transformer.runtime.training \
--amp "$AMP" \
--batch_size "$BATCH_SIZE" \
--epochs 6 \
--use_layer_norm \
--norm \
--save_ckpt_path model_qm9.pth \
--task homo \
--precompute_bases \
--seed 42 \
--benchmark

@@ -0,0 +1,19 @@
#!/usr/bin/env bash
# CLI args with defaults
BATCH_SIZE=${1:-240}
AMP=${2:-true}
# choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
# 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'
TASK=homo
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \
se3_transformer.runtime.inference \
--amp "$AMP" \
--batch_size "$BATCH_SIZE" \
--use_layer_norm \
--norm \
--load_ckpt_path model_qm9.pth \
--task "$TASK"

@@ -0,0 +1,25 @@
#!/usr/bin/env bash
# CLI args with defaults
BATCH_SIZE=${1:-240}
AMP=${2:-true}
NUM_EPOCHS=${3:-100}
LEARNING_RATE=${4:-0.002}
WEIGHT_DECAY=${5:-0.1}
# choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
# 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'
TASK=homo
python -m se3_transformer.runtime.training \
--amp "$AMP" \
--batch_size "$BATCH_SIZE" \
--epochs "$NUM_EPOCHS" \
--lr "$LEARNING_RATE" \
--weight_decay "$WEIGHT_DECAY" \
--use_layer_norm \
--norm \
--save_ckpt_path model_qm9.pth \
--precompute_bases \
--seed 42 \
--task "$TASK"

@@ -0,0 +1,27 @@
#!/usr/bin/env bash
# CLI args with defaults
BATCH_SIZE=${1:-240}
AMP=${2:-true}
NUM_EPOCHS=${3:-300}
LEARNING_RATE=${4:-0.004}
WEIGHT_DECAY=${5:-0.1}
# choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
# 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'
TASK=homo
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \
se3_transformer.runtime.training \
--amp "$AMP" \
--batch_size "$BATCH_SIZE" \
--epochs "$NUM_EPOCHS" \
--lr "$LEARNING_RATE" \
--min_lr 0.0001 \
--weight_decay "$WEIGHT_DECAY" \
--use_layer_norm \
--norm \
--save_ckpt_path model_qm9.pth \
--precompute_bases \
--seed 42 \
--task "$TASK"

@@ -0,0 +1 @@
from .qm9 import QM9DataModule

@@ -0,0 +1,63 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import torch.distributed as dist
from abc import ABC
from torch.utils.data import DataLoader, DistributedSampler, Dataset

from se3_transformer.runtime.utils import get_local_rank


def _get_dataloader(dataset: Dataset, shuffle: bool, **kwargs) -> DataLoader:
    # Classic or distributed dataloader depending on the context
    sampler = DistributedSampler(dataset, shuffle=shuffle) if dist.is_initialized() else None
    return DataLoader(dataset, shuffle=(shuffle and sampler is None), sampler=sampler, **kwargs)


class DataModule(ABC):
    """ Abstract DataModule. Children must define self.ds_{train | val | test}. """

    def __init__(self, **dataloader_kwargs):
        super().__init__()
        if get_local_rank() == 0:
            self.prepare_data()
        # Wait until rank zero has prepared the data (download, preprocessing, ...)
        if dist.is_initialized():
            dist.barrier(device_ids=[get_local_rank()])
        self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs}
        self.ds_train, self.ds_val, self.ds_test = None, None, None

    def prepare_data(self):
        """ Method called only once per node. Put here any downloading or preprocessing """
        pass

    def train_dataloader(self) -> DataLoader:
        return _get_dataloader(self.ds_train, shuffle=True, **self.dataloader_kwargs)

    def val_dataloader(self) -> DataLoader:
        return _get_dataloader(self.ds_val, shuffle=False, **self.dataloader_kwargs)

    def test_dataloader(self) -> DataLoader:
        return _get_dataloader(self.ds_test, shuffle=False, **self.dataloader_kwargs)

@@ -0,0 +1,171 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Tuple

import dgl
import pathlib
import torch
from dgl.data import QM9EdgeDataset
from dgl import DGLGraph
from torch import Tensor
from torch.utils.data import random_split, DataLoader, Dataset
from tqdm import tqdm

from se3_transformer.data_loading.data_module import DataModule
from se3_transformer.model.basis import get_basis
from se3_transformer.runtime.utils import get_local_rank, str2bool, using_tensor_cores


def _get_relative_pos(qm9_graph: DGLGraph) -> Tensor:
    x = qm9_graph.ndata['pos']
    src, dst = qm9_graph.edges()
    rel_pos = x[dst] - x[src]
    return rel_pos


def _get_split_sizes(full_dataset: Dataset) -> Tuple[int, int, int]:
    len_full = len(full_dataset)
    len_train = 100_000
    len_test = int(0.1 * len_full)
    len_val = len_full - len_train - len_test
    return len_train, len_val, len_test


class QM9DataModule(DataModule):
    """
    Datamodule wrapping https://docs.dgl.ai/en/latest/api/python/dgl.data.html#qm9edge-dataset
    Training set is 100k molecules. Test set is 10% of the dataset. Validation set is the rest.
    This includes all the molecules from QM9 except the ones that are uncharacterized.
    """

    NODE_FEATURE_DIM = 6
    EDGE_FEATURE_DIM = 4

    def __init__(self,
                 data_dir: pathlib.Path,
                 task: str = 'homo',
                 batch_size: int = 240,
                 num_workers: int = 8,
                 num_degrees: int = 4,
                 amp: bool = False,
                 precompute_bases: bool = False,
                 **kwargs):
        self.data_dir = data_dir  # This needs to be before __init__ so that prepare_data has access to it
        super().__init__(batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate)
        self.amp = amp
        self.task = task
        self.batch_size = batch_size
        self.num_degrees = num_degrees

        qm9_kwargs = dict(label_keys=[self.task], verbose=False, raw_dir=str(data_dir))
        if precompute_bases:
            bases_kwargs = dict(max_degree=num_degrees - 1, use_pad_trick=using_tensor_cores(amp), amp=amp)
            full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size, **qm9_kwargs)
        else:
            full_dataset = QM9EdgeDataset(**qm9_kwargs)

        self.ds_train, self.ds_val, self.ds_test = random_split(full_dataset, _get_split_sizes(full_dataset),
                                                                generator=torch.Generator().manual_seed(0))

        train_targets = full_dataset.targets[self.ds_train.indices, full_dataset.label_keys[0]]
        self.targets_mean = train_targets.mean()
        self.targets_std = train_targets.std()

    def prepare_data(self):
        # Download the QM9 preprocessed data
        QM9EdgeDataset(verbose=True, raw_dir=str(self.data_dir))

    def _collate(self, samples):
        graphs, y, *bases = map(list, zip(*samples))
        batched_graph = dgl.batch(graphs)
        edge_feats = {'0': batched_graph.edata['edge_attr'][..., None]}
        batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
        # get node features
        node_feats = {'0': batched_graph.ndata['attr'][:, :6, None]}
        targets = (torch.cat(y) - self.targets_mean) / self.targets_std

        if bases:
            # collate bases
            all_bases = {
                key: torch.cat([b[key] for b in bases[0]], dim=0)
                for key in bases[0][0].keys()
            }
            return batched_graph, node_feats, edge_feats, all_bases, targets
        else:
            return batched_graph, node_feats, edge_feats, targets

    @staticmethod
    def add_argparse_args(parent_parser):
        parser = parent_parser.add_argument_group("QM9 dataset")
        parser.add_argument('--task', type=str, default='homo', const='homo', nargs='?',
                            choices=['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
                                     'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'],
                            help='Regression task to train on')
        parser.add_argument('--precompute_bases', type=str2bool, nargs='?', const=True, default=False,
                            help='Precompute bases at the beginning of the script during dataset initialization,'
                                 ' instead of computing them at the beginning of each forward pass.')
        return parent_parser

    def __repr__(self):
        return f'QM9({self.task})'


class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
    """ Dataset extending the QM9 dataset from DGL with precomputed (cached in RAM) pairwise bases """

    def __init__(self, bases_kwargs: dict, batch_size: int, *args, **kwargs):
        """
        :param bases_kwargs: Arguments to feed the bases computation function
        :param batch_size:   Batch size to use when iterating over the dataset for computing bases
        """
        self.bases_kwargs = bases_kwargs
        self.batch_size = batch_size
        self.bases = None
        super().__init__(*args, **kwargs)

    def load(self):
        super().load()
        # Iterate through the dataset and compute bases (pairwise only)
        # Potential improvement: use multi-GPU and reduction
        dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size,
                                collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples]))
        bases = []
        for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases',
                             disable=get_local_rank() != 0):
            rel_pos = _get_relative_pos(graph)
            # Compute the bases with the GPU but convert the result to CPU to store in RAM
            bases.append({k: v.cpu() for k, v in get_basis(rel_pos.cuda(), **self.bases_kwargs).items()})
        self.bases = bases  # Assign at the end so that __getitem__ isn't confused

    def __getitem__(self, idx: int):
        graph, label = super().__getitem__(idx)

        if self.bases:
            bases_idx = idx // self.batch_size
            bases_cumsum_idx = self.ne_cumsum[idx] - self.ne_cumsum[bases_idx * self.batch_size]
            bases_cumsum_next_idx = self.ne_cumsum[idx + 1] - self.ne_cumsum[bases_idx * self.batch_size]
            return graph, label, {key: basis[bases_cumsum_idx:bases_cumsum_next_idx] for key, basis in
                                  self.bases[bases_idx].items()}
        else:
            return graph, label

@@ -0,0 +1,2 @@
from .transformer import SE3Transformer, SE3TransformerPooled
from .fiber import Fiber

@@ -0,0 +1,178 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from functools import lru_cache
from typing import Dict, List
import e3nn.o3 as o3
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.runtime.utils import degree_to_dim
@lru_cache(maxsize=None)
def get_clebsch_gordon(J: int, d_in: int, d_out: int, device) -> Tensor:
""" Get the (cached) Q^{d_out,d_in}_J matrices from equation (8) """
return o3.wigner_3j(J, d_in, d_out, dtype=torch.float64, device=device).permute(2, 1, 0)
@lru_cache(maxsize=None)
def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]:
all_cb = []
for d_in in range(max_degree + 1):
for d_out in range(max_degree + 1):
K_Js = []
for J in range(abs(d_in - d_out), d_in + d_out + 1):
K_Js.append(get_clebsch_gordon(J, d_in, d_out, device))
all_cb.append(K_Js)
return all_cb
def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]:
all_degrees = list(range(2 * max_degree + 1))
with nvtx_range('spherical harmonics'):
sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)
@torch.jit.script
def get_basis_script(max_degree: int,
use_pad_trick: bool,
spherical_harmonics: List[Tensor],
clebsch_gordon: List[List[Tensor]],
amp: bool) -> Dict[str, Tensor]:
"""
Compute pairwise bases matrices for degrees up to max_degree
:param max_degree: Maximum input or output degree
:param use_pad_trick: Pad some of the odd dimensions for a better use of Tensor Cores
:param spherical_harmonics: List of computed spherical harmonics
:param clebsch_gordon: List of computed CB-coefficients
:param amp: When true, return bases in FP16 precision
"""
basis = {}
idx = 0
# Double for loop instead of product() because of JIT script
for d_in in range(max_degree + 1):
for d_out in range(max_degree + 1):
key = f'{d_in},{d_out}'
K_Js = []
for freq_idx, J in enumerate(range(abs(d_in - d_out), d_in + d_out + 1)):
Q_J = clebsch_gordon[idx][freq_idx]
K_Js.append(torch.einsum('n f, k l f -> n l k', spherical_harmonics[J].float(), Q_J.float()))
basis[key] = torch.stack(K_Js, 2) # Stack on second dim so order is n l f k
if amp:
basis[key] = basis[key].half()
if use_pad_trick:
basis[key] = F.pad(basis[key], (0, 1)) # Pad the k dimension, that can be sliced later
idx += 1
return basis
@torch.jit.script
def update_basis_with_fused(basis: Dict[str, Tensor],
max_degree: int,
use_pad_trick: bool,
fully_fused: bool) -> Dict[str, Tensor]:
""" Update the basis dict with partially and optionally fully fused bases """
num_edges = basis['0,0'].shape[0]
device = basis['0,0'].device
dtype = basis['0,0'].dtype
sum_dim = sum([degree_to_dim(d) for d in range(max_degree + 1)])
# Fused per output degree
for d_out in range(max_degree + 1):
sum_freq = sum([degree_to_dim(min(d, d_out)) for d in range(max_degree + 1)])
basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, degree_to_dim(d_out) + int(use_pad_trick),
device=device, dtype=dtype)
acc_d, acc_f = 0, 0
for d_in in range(max_degree + 1):
basis_fused[:, acc_d:acc_d + degree_to_dim(d_in), acc_f:acc_f + degree_to_dim(min(d_out, d_in)),
:degree_to_dim(d_out)] = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]
acc_d += degree_to_dim(d_in)
acc_f += degree_to_dim(min(d_out, d_in))
basis[f'out{d_out}_fused'] = basis_fused
# Fused per input degree
for d_in in range(max_degree + 1):
sum_freq = sum([degree_to_dim(min(d, d_in)) for d in range(max_degree + 1)])
basis_fused = torch.zeros(num_edges, degree_to_dim(d_in), sum_freq, sum_dim,
device=device, dtype=dtype)
acc_d, acc_f = 0, 0
for d_out in range(max_degree + 1):
basis_fused[:, :, acc_f:acc_f + degree_to_dim(min(d_out, d_in)), acc_d:acc_d + degree_to_dim(d_out)] \
= basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]
acc_d += degree_to_dim(d_out)
acc_f += degree_to_dim(min(d_out, d_in))
basis[f'in{d_in}_fused'] = basis_fused
if fully_fused:
# Fully fused
# Double sum this way because of JIT script
sum_freq = sum([
sum([degree_to_dim(min(d_in, d_out)) for d_in in range(max_degree + 1)]) for d_out in range(max_degree + 1)
])
basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, sum_dim, device=device, dtype=dtype)
acc_d, acc_f = 0, 0
for d_out in range(max_degree + 1):
b = basis[f'out{d_out}_fused']
basis_fused[:, :, acc_f:acc_f + b.shape[2], acc_d:acc_d + degree_to_dim(d_out)] = b[:, :, :,
:degree_to_dim(d_out)]
acc_f += b.shape[2]
acc_d += degree_to_dim(d_out)
basis['fully_fused'] = basis_fused
del basis['0,0'] # We know that the basis for l = k = 0 is filled with a constant
return basis
def get_basis(relative_pos: Tensor,
max_degree: int = 4,
compute_gradients: bool = False,
use_pad_trick: bool = False,
amp: bool = False) -> Dict[str, Tensor]:
with nvtx_range('spherical harmonics'):
spherical_harmonics = get_spherical_harmonics(relative_pos, max_degree)
with nvtx_range('CB coefficients'):
clebsch_gordon = get_all_clebsch_gordon(max_degree, relative_pos.device)
with torch.autograd.set_grad_enabled(compute_gradients):
with nvtx_range('bases'):
basis = get_basis_script(max_degree=max_degree,
use_pad_trick=use_pad_trick,
spherical_harmonics=spherical_harmonics,
clebsch_gordon=clebsch_gordon,
amp=amp)
return basis
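
# Usage sketch (illustrative, not part of the release; assumes a CUDA build of
# PyTorch, since the nvtx ranges above require it). For each pair of degrees,
# basis['{d_in},{d_out}'] has shape (num_edges, 2*d_in+1, num_freq, 2*d_out+1).
if __name__ == '__main__':
    rel_pos = torch.randn(16, 3)  # 16 edges with 3D relative positions
    example_basis = get_basis(rel_pos, max_degree=2)
    for pair, tensor in example_basis.items():
        print(pair, tuple(tensor.shape))  # e.g. '1,2' -> (16, 3, 3, 5)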

View file

@ -0,0 +1,144 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from collections import namedtuple
from itertools import product
from typing import Dict
import torch
from torch import Tensor
from se3_transformer.runtime.utils import degree_to_dim
FiberEl = namedtuple('FiberEl', ['degree', 'channels'])
class Fiber(dict):
"""
Describes the structure of some set of features.
Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1.
Type-0 features: invariant scalars
Type-1 features: equivariant 3D vectors
Type-2 features: equivariant symmetric traceless matrices
...
    As input to an SE3 layer, there can be many features of the same type, and many features of different types.
The 'multiplicity' or 'number of channels' is the number of features of a given type.
This class puts together all the degrees and their multiplicities in order to describe
the inputs, outputs or hidden features of SE3 layers.
"""
def __init__(self, structure):
if isinstance(structure, dict):
structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])]
elif not isinstance(structure[0], FiberEl):
structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1])))
self.structure = structure
super().__init__({d: m for d, m in self.structure})
@property
def degrees(self):
return sorted([t.degree for t in self.structure])
@property
def channels(self):
return [self[d] for d in self.degrees]
@property
def num_features(self):
""" Size of the resulting tensor if all features were concatenated together """
return sum(t.channels * degree_to_dim(t.degree) for t in self.structure)
@staticmethod
def create(num_degrees: int, num_channels: int):
""" Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """
return Fiber([(degree, num_channels) for degree in range(num_degrees)])
@staticmethod
def from_features(feats: Dict[str, Tensor]):
""" Infer the Fiber structure from a feature dict """
structure = {}
for k, v in feats.items():
degree = int(k)
assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
assert v.shape[-1] == degree_to_dim(degree)
structure[degree] = v.shape[-2]
return Fiber(structure)
def __getitem__(self, degree: int):
""" fiber[degree] returns the multiplicity for this degree """
return dict(self.structure).get(degree, 0)
def __iter__(self):
""" Iterate over namedtuples (degree, channels) """
return iter(self.structure)
def __mul__(self, other):
"""
        If other is an int, multiplies all the multiplicities by other.
        If other is a Fiber, returns the Cartesian product.
"""
if isinstance(other, Fiber):
return product(self.structure, other.structure)
elif isinstance(other, int):
return Fiber({t.degree: t.channels * other for t in self.structure})
def __add__(self, other):
"""
        If other is an int, adds other to all the multiplicities.
        If other is a Fiber, adds the multiplicities of the two fibers degree-wise.
"""
if isinstance(other, Fiber):
return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure})
elif isinstance(other, int):
return Fiber({t.degree: t.channels + other for t in self.structure})
def __repr__(self):
return str(self.structure)
@staticmethod
def combine_max(f1, f2):
""" Combine two fiber by taking the maximum multiplicity for each degree in both fibers """
new_dict = dict(f1.structure)
for k, m in f2.structure:
new_dict[k] = max(new_dict.get(k, 0), m)
return Fiber(list(new_dict.items()))
@staticmethod
def combine_selectively(f1, f2):
""" Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """
# only use orders which occur in fiber f1
new_dict = dict(f1.structure)
for k in f1.degrees:
if k in f2.degrees:
new_dict[k] += f2[k]
return Fiber(list(new_dict.items()))
def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int):
# dict(N, num_channels, 2d+1) -> (N, num_heads, -1)
fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in
self.degrees]
fibers = torch.cat(fibers, -1)
return fibers
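
# Usage sketch (illustrative; sizes are made up): a Fiber maps degrees to
# multiplicities, and the arithmetic operators work degree-wise.
if __name__ == '__main__':
    hidden = Fiber.create(num_degrees=3, num_channels=32)  # {0: 32, 1: 32, 2: 32}
    out = Fiber({0: 16, 1: 8})
    print(hidden.degrees, hidden.channels)  # [0, 1, 2] [32, 32, 32]
    print((out * 2)[1])                     # 16: all multiplicities doubled
    print((hidden + out).channels)          # [48, 40, 32]: degree-wise sum
    print(hidden.num_features)              # 288 = 32*1 + 32*3 + 32*5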

View file

@ -0,0 +1,5 @@
from .linear import LinearSE3
from .norm import NormSE3
from .pooling import GPooling
from .convolution import ConvSE3
from .attention import AttentionBlockSE3

View file

@ -0,0 +1,180 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import dgl
import numpy as np
import torch
import torch.nn as nn
from dgl import DGLGraph
from dgl.ops import edge_softmax
from torch import Tensor
from typing import Dict, Optional, Union
from se3_transformer.model.fiber import Fiber
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
from torch.cuda.nvtx import range as nvtx_range
class AttentionSE3(nn.Module):
""" Multi-headed sparse graph self-attention (SE(3)-equivariant) """
def __init__(
self,
num_heads: int,
key_fiber: Fiber,
value_fiber: Fiber
):
"""
:param num_heads: Number of attention heads
:param key_fiber: Fiber for the keys (and also for the queries)
:param value_fiber: Fiber for the values
"""
super().__init__()
self.num_heads = num_heads
self.key_fiber = key_fiber
self.value_fiber = value_fiber
def forward(
self,
value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
query: Dict[str, Tensor], # node features
graph: DGLGraph
):
with nvtx_range('AttentionSE3'):
with nvtx_range('reshape keys and queries'):
if isinstance(key, Tensor):
# case where features of all types are fused
key = key.reshape(key.shape[0], self.num_heads, -1)
# need to reshape queries that way to keep the same layout as keys
out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
else:
# features are not fused, need to fuse and reshape them
key = self.key_fiber.to_attention_heads(key, self.num_heads)
query = self.key_fiber.to_attention_heads(query, self.num_heads)
with nvtx_range('attention dot product + softmax'):
# Compute attention weights (softmax of inner product between key and query)
edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
edge_weights /= np.sqrt(self.key_fiber.num_features)
edge_weights = edge_softmax(graph, edge_weights)
edge_weights = edge_weights[..., None, None]
with nvtx_range('weighted sum'):
if isinstance(value, Tensor):
# features of all types are fused
v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
weights = edge_weights * v
feat_out = dgl.ops.copy_e_sum(graph, weights)
feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads
out = unfuse_features(feat_out, self.value_fiber.degrees)
else:
out = {}
for degree, channels in self.value_fiber:
v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
degree_to_dim(degree))
weights = edge_weights * v
res = dgl.ops.copy_e_sum(graph, weights)
out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads
return out
class AttentionBlockSE3(nn.Module):
""" Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """
def __init__(
self,
fiber_in: Fiber,
fiber_out: Fiber,
fiber_edge: Optional[Fiber] = None,
num_heads: int = 4,
channels_div: int = 2,
use_layer_norm: bool = False,
            max_degree: int = 4,
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
**kwargs
):
"""
:param fiber_in: Fiber describing the input features
:param fiber_out: Fiber describing the output features
:param fiber_edge: Fiber describing the edge features (node distances excluded)
:param num_heads: Number of attention heads
:param channels_div: Divide the channels by this integer for computing values
:param use_layer_norm: Apply layer normalization between MLP layers
:param max_degree: Maximum degree used in the bases computation
:param fuse_level: Maximum fuse level to use in TFN convolutions
"""
super().__init__()
if fiber_edge is None:
fiber_edge = Fiber({})
self.fiber_in = fiber_in
# value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out])
        # key_query_fiber has the same structure as value_fiber, but only the degrees that are also in fiber_in
        # (queries are merely projected, hence degrees have to match input)
key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])
self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
allow_fused_output=True)
self.to_query = LinearSE3(fiber_in, key_query_fiber)
self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
self.project = LinearSE3(value_fiber + fiber_in, fiber_out)
def forward(
self,
node_features: Dict[str, Tensor],
edge_features: Dict[str, Tensor],
graph: DGLGraph,
basis: Dict[str, Tensor]
):
with nvtx_range('AttentionBlockSE3'):
with nvtx_range('keys / values'):
fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
key, value = self._get_key_value_from_fused(fused_key_value)
with nvtx_range('queries'):
query = self.to_query(node_features)
z = self.attention(value, key, query, graph)
z_concat = aggregate_residual(node_features, z, 'cat')
return self.project(z_concat)
def _get_key_value_from_fused(self, fused_key_value):
        # Extract keys and values from the fused features
if isinstance(fused_key_value, Tensor):
# Previous layer was a fully fused convolution
value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
else:
key, value = {}, {}
for degree, feat in fused_key_value.items():
if int(degree) in self.fiber_in.degrees:
value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
else:
value[degree] = feat
return key, value
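
# Shape sketch (illustrative; sizes are made up): keys and values coming out of a
# fully fused convolution are stacked along the channel dimension, so chunking
# dim -2 in half recovers them, mirroring _get_key_value_from_fused above.
if __name__ == '__main__':
    num_edges, channels, sum_dim = 10, 16, 9  # 9 = 1 + 3 + 5 for degrees 0..2
    fused = torch.randn(num_edges, 2 * channels, sum_dim)
    value, key = torch.chunk(fused, chunks=2, dim=-2)
    print(value.shape, key.shape)  # both torch.Size([10, 16, 9])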

View file

@ -0,0 +1,335 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from enum import Enum
from itertools import product
from typing import Dict
import dgl
import numpy as np
import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime.utils import degree_to_dim, unfuse_features
class ConvSE3FuseLevel(Enum):
"""
Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met.
    If the desired level L cannot be applied to a given layer, fused ops of lower levels (< L) are considered instead.
A higher level means faster training, but also more memory usage.
If you are tight on memory and want to feed large inputs to the network, choose a low value.
If you want to train fast, choose a high value.
Recommended value is FULL with AMP.
Fully fused TFN convolutions requirements:
- all input channels are the same
- all output channels are the same
- input degrees span the range [0, ..., max_degree]
- output degrees span the range [0, ..., max_degree]
Partially fused TFN convolutions requirements:
* For fusing by output degree:
- all input channels are the same
- input degrees span the range [0, ..., max_degree]
* For fusing by input degree:
- all output channels are the same
- output degrees span the range [0, ..., max_degree]
Original TFN pairwise convolutions: no requirements
"""
FULL = 2
PARTIAL = 1
NONE = 0
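
# Illustrative helper (an assumption, not part of the release): a direct check of
# the fully fused requirements listed above, ignoring the extra edge-feature
# channels that ConvSE3 also folds into its input channel count.
def _full_fusion_applicable(fiber_in: Fiber, fiber_out: Fiber, max_degree: int) -> bool:
    degrees = list(range(max_degree + 1))
    return (len(set(fiber_in.channels)) == 1
            and len(set(fiber_out.channels)) == 1
            and fiber_in.degrees == degrees
            and fiber_out.degrees == degrees)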
class RadialProfile(nn.Module):
"""
Radial profile function.
Outputs weights used to weigh basis matrices in order to get convolution kernels.
In TFN notation: $R^{l,k}$
In SE(3)-Transformer notation: $\phi^{l,k}$
Note:
In the original papers, this function only depends on relative node distances ||x||.
Here, we allow this function to also take as input additional invariant edge features.
This does not break equivariance and adds expressive power to the model.
Diagram:
        invariant edge features (node distances included) ──> MLP layer (shared across edges) ──> radial weights
"""
def __init__(
self,
num_freq: int,
channels_in: int,
channels_out: int,
edge_dim: int = 1,
mid_dim: int = 32,
use_layer_norm: bool = False
):
"""
:param num_freq: Number of frequencies
:param channels_in: Number of input channels
:param channels_out: Number of output channels
:param edge_dim: Number of invariant edge features (input to the radial function)
:param mid_dim: Size of the hidden MLP layers
:param use_layer_norm: Apply layer normalization between MLP layers
"""
super().__init__()
modules = [
nn.Linear(edge_dim, mid_dim),
nn.LayerNorm(mid_dim) if use_layer_norm else None,
nn.ReLU(),
nn.Linear(mid_dim, mid_dim),
nn.LayerNorm(mid_dim) if use_layer_norm else None,
nn.ReLU(),
nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
]
self.net = nn.Sequential(*[m for m in modules if m is not None])
def forward(self, features: Tensor) -> Tensor:
return self.net(features)
class VersatileConvSE3(nn.Module):
"""
Building block for TFN convolutions.
This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions.
"""
def __init__(self,
freq_sum: int,
channels_in: int,
channels_out: int,
edge_dim: int,
use_layer_norm: bool,
fuse_level: ConvSE3FuseLevel):
super().__init__()
self.freq_sum = freq_sum
self.channels_out = channels_out
self.channels_in = channels_in
self.fuse_level = fuse_level
self.radial_func = RadialProfile(num_freq=freq_sum,
channels_in=channels_in,
channels_out=channels_out,
edge_dim=edge_dim,
use_layer_norm=use_layer_norm)
def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
        with nvtx_range('VersatileConvSE3'):
num_edges = features.shape[0]
in_dim = features.shape[2]
            with nvtx_range('RadialProfile'):
radial_weights = self.radial_func(invariant_edge_feats) \
.view(-1, self.channels_out, self.channels_in * self.freq_sum)
if basis is not None:
# This block performs the einsum n i l, n o i f, n l f k -> n o k
out_dim = basis.shape[-1]
if self.fuse_level != ConvSE3FuseLevel.FULL:
out_dim += out_dim % 2 - 1 # Account for padded basis
basis_view = basis.view(num_edges, in_dim, -1)
tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
return (radial_weights @ tmp)[:, :, :out_dim]
else:
# k = l = 0 non-fused case
return radial_weights @ features
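
# Reference sketch (illustrative helper, not part of the release): the two batched
# matmuls above are equivalent to the einsum named in the comment, with n: edges,
# i/o: input/output channels, l/k: input/output degree dims, f: frequencies.
def _versatile_conv_reference(feats: Tensor, radial: Tensor, basis: Tensor) -> Tensor:
    """ Direct einsum equivalent of VersatileConvSE3's matmul + reshape path.
    feats: (n, i, l), radial: (n, o, i, f), basis: (n, l, f, k) -> (n, o, k) """
    return torch.einsum('nil,noif,nlfk->nok', feats, radial, basis)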
class ConvSE3(nn.Module):
"""
SE(3)-equivariant graph convolution (Tensor Field Network convolution).
This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance.
Features of different degrees interact together to produce output features.
Note 1:
The option is given to not pool the output. This means that the convolution sum over neighbors will not be
done, and the returned features will be edge features instead of node features.
Note 2:
        Unlike the original paper and implementation, this convolution can handle edge features of degree greater than 0.
Input edge features are concatenated with input source node features before the kernel is applied.
"""
def __init__(
self,
fiber_in: Fiber,
fiber_out: Fiber,
fiber_edge: Fiber,
pool: bool = True,
use_layer_norm: bool = False,
self_interaction: bool = False,
max_degree: int = 4,
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
allow_fused_output: bool = False
):
"""
:param fiber_in: Fiber describing the input features
:param fiber_out: Fiber describing the output features
:param fiber_edge: Fiber describing the edge features (node distances excluded)
:param pool: If True, compute final node features by averaging incoming edge features
:param use_layer_norm: Apply layer normalization between MLP layers
:param self_interaction: Apply self-interaction of nodes
:param max_degree: Maximum degree used in the bases computation
:param fuse_level: Maximum fuse level to use in TFN convolutions
:param allow_fused_output: Allow the module to output a fused representation of features
"""
super().__init__()
self.pool = pool
self.fiber_in = fiber_in
self.fiber_out = fiber_out
self.self_interaction = self_interaction
self.max_degree = max_degree
self.allow_fused_output = allow_fused_output
# channels_in: account for the concatenation of edge features
channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
channels_out_set = set([f.channels for f in self.fiber_out])
unique_channels_in = (len(channels_in_set) == 1)
unique_channels_out = (len(channels_out_set) == 1)
degrees_up_to_max = list(range(max_degree + 1))
common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm)
if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
# Single fused convolution
self.used_fuse_level = ConvSE3FuseLevel.FULL
sum_freq = sum([
degree_to_dim(min(d_in, d_out))
for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
])
self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0],
fuse_level=self.used_fuse_level, **common_args)
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
unique_channels_in and fiber_in.degrees == degrees_up_to_max:
# Convolutions fused per output degree
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
self.conv_out = nn.ModuleDict()
for d_out, c_out in fiber_out:
sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out,
fuse_level=self.used_fuse_level, **common_args)
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
# Convolutions fused per input degree
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
self.conv_in = nn.ModuleDict()
for d_in, c_in in fiber_in:
sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, c_in, list(channels_out_set)[0],
fuse_level=self.used_fuse_level, **common_args)
else:
# Use pairwise TFN convolutions
self.used_fuse_level = ConvSE3FuseLevel.NONE
self.conv = nn.ModuleDict()
for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out):
dict_key = f'{degree_in},{degree_out}'
channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
sum_freq = degree_to_dim(min(degree_in, degree_out))
self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out,
fuse_level=self.used_fuse_level, **common_args)
if self_interaction:
self.to_kernel_self = nn.ParameterDict()
for degree_out, channels_out in fiber_out:
if fiber_in[degree_out]:
self.to_kernel_self[str(degree_out)] = nn.Parameter(
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
def forward(
self,
node_feats: Dict[str, Tensor],
edge_feats: Dict[str, Tensor],
graph: DGLGraph,
basis: Dict[str, Tensor]
):
        with nvtx_range('ConvSE3'):
invariant_edge_feats = edge_feats['0'].squeeze(-1)
src, dst = graph.edges()
out = {}
in_features = []
# Fetch all input features from edge and node features
for degree_in in self.fiber_in.degrees:
src_node_features = node_feats[str(degree_in)][src]
if degree_in > 0 and str(degree_in) in edge_feats:
# Handle edge features of any type by concatenating them to node features
src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
in_features.append(src_node_features)
if self.used_fuse_level == ConvSE3FuseLevel.FULL:
in_features_fused = torch.cat(in_features, dim=-1)
out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])
if not self.allow_fused_output or self.self_interaction or self.pool:
out = unfuse_features(out, self.fiber_out.degrees)
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
in_features_fused = torch.cat(in_features, dim=-1)
for degree_out in self.fiber_out.degrees:
out[str(degree_out)] = self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats,
basis[f'out{degree_out}_fused'])
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
out = 0
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
out += self.conv_in[str(degree_in)](feature, invariant_edge_feats,
basis[f'in{degree_in}_fused'])
if not self.allow_fused_output or self.self_interaction or self.pool:
out = unfuse_features(out, self.fiber_out.degrees)
else:
# Fallback to pairwise TFN convolutions
for degree_out in self.fiber_out.degrees:
out_feature = 0
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
dict_key = f'{degree_in},{degree_out}'
out_feature = out_feature + self.conv[dict_key](feature, invariant_edge_feats,
basis.get(dict_key, None))
out[str(degree_out)] = out_feature
for degree_out in self.fiber_out.degrees:
if self.self_interaction and str(degree_out) in self.to_kernel_self:
                    with nvtx_range('self interaction'):
dst_features = node_feats[str(degree_out)][dst]
kernel_self = self.to_kernel_self[str(degree_out)]
out[str(degree_out)] += kernel_self @ dst_features
if self.pool:
                    with nvtx_range('pooling'):
if isinstance(out, dict):
out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
else:
out = dgl.ops.copy_e_sum(graph, out)
return out

View file

@ -0,0 +1,59 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from se3_transformer.model.fiber import Fiber
class LinearSE3(nn.Module):
"""
Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
Maps a fiber to a fiber with the same degrees (channels may be different).
No interaction between degrees, but interaction between channels.
    type-0 features (C_0 channels) ──> Linear(bias=False) ──> type-0 features (C'_0 channels)
    type-1 features (C_1 channels) ──> Linear(bias=False) ──> type-1 features (C'_1 channels)
                                              :
    type-k features (C_k channels) ──> Linear(bias=False) ──> type-k features (C'_k channels)
"""
def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
super().__init__()
self.weights = nn.ParameterDict({
str(degree_out): nn.Parameter(
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
for degree_out, channels_out in fiber_out
})
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
        return {
            degree: weight @ features[degree]
            for degree, weight in self.weights.items()
        }
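
# Usage sketch (illustrative; sizes are made up): each degree gets its own
# bias-free linear map over the channel dimension; the 2d+1 components are
# left untouched.
if __name__ == '__main__':
    lin = LinearSE3(Fiber({0: 8, 1: 8}), Fiber({0: 16, 1: 4}))
    feats = {'0': torch.randn(10, 8, 1), '1': torch.randn(10, 8, 3)}
    out = lin(feats)
    print(out['0'].shape, out['1'].shape)  # torch.Size([10, 16, 1]) torch.Size([10, 4, 3])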

View file

@ -0,0 +1,83 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Dict
import torch
import torch.nn as nn
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.model.fiber import Fiber
class NormSE3(nn.Module):
"""
Norm-based SE(3)-equivariant nonlinearity.
                 ┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐
    feature_in ──┤                                              * ──> feature_out
                 └──> feature_phase ────────────────────────────┘
"""
NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
super().__init__()
self.fiber = fiber
self.nonlinearity = nonlinearity
if len(set(fiber.channels)) == 1:
# Fuse all the layer normalizations into a group normalization
self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
else:
# Use multiple layer normalizations
self.layer_norms = nn.ModuleDict({
str(degree): nn.LayerNorm(channels)
for degree, channels in fiber
})
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
with nvtx_range('NormSE3'):
output = {}
if hasattr(self, 'group_norm'):
# Compute per-degree norms of features
norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
for d in self.fiber.degrees]
fused_norms = torch.cat(norms, dim=-2)
# Transform the norms only
new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)
# Scale features to the new norms
for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
output[str(d)] = features[str(d)] / norm * new_norm
else:
for degree, feat in features.items():
norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
output[degree] = new_norm * feat / norm
return output
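
# Equivariance sketch (illustrative; sizes are made up): only the norms pass
# through the nonlinearity, so every output vector is a non-negative rescaling of
# its input vector; the cross product of parallel (or zeroed) vectors is ~0.
if __name__ == '__main__':
    norm = NormSE3(Fiber({0: 8, 1: 8}))
    feats = {'0': torch.randn(10, 8, 1), '1': torch.randn(10, 8, 3)}
    out = norm(feats)
    print(torch.cross(out['1'], feats['1'], dim=-1).abs().max() < 1e-4)  # tensor(True)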

View file

@ -0,0 +1,53 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Dict, Literal
import torch.nn as nn
from dgl import DGLGraph
from dgl.nn.pytorch import AvgPooling, MaxPooling
from torch import Tensor
class GPooling(nn.Module):
"""
Graph max/average pooling on a given feature type.
The average can be taken for any feature type, and equivariance will be maintained.
The maximum can only be taken for invariant features (type 0).
If you want max-pooling for type > 0 features, look into Vector Neurons.
"""
def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
"""
:param feat_type: Feature type to pool
:param pool: Type of pooling: max or avg
"""
super().__init__()
assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
self.feat_type = feat_type
self.pool = MaxPooling() if pool == 'max' else AvgPooling()
def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
pooled = self.pool(graph, features[str(self.feat_type)])
return pooled.squeeze(dim=-1)
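
# Usage sketch (illustrative; sizes are made up): average-pooling type-1 features
# over the nodes of each graph in a batch of two random DGL graphs.
if __name__ == '__main__':
    import dgl
    import torch
    graph = dgl.batch([dgl.rand_graph(5, 10), dgl.rand_graph(7, 12)])
    feats = {'1': torch.randn(12, 8, 3)}  # 12 nodes total, 8 channels, 2*1+1 dims
    pooled = GPooling(feat_type=1, pool='avg')(feats, graph=graph)
    print(pooled.shape)  # torch.Size([2, 8, 3]): one pooled feature per graph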

View file

@ -0,0 +1,222 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import logging
from typing import Optional, Literal, Dict
import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor
from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.layers.attention import AttentionBlockSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling
from se3_transformer.runtime.utils import str2bool
from se3_transformer.model.fiber import Fiber
class Sequential(nn.Sequential):
""" Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """
def forward(self, input, *args, **kwargs):
for module in self:
input = module(input, *args, **kwargs)
return input
def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
""" Add relative positions to existing edge features """
edge_features = edge_features.copy() if edge_features else {}
r = relative_pos.norm(dim=-1, keepdim=True)
if '0' in edge_features:
edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
else:
edge_features['0'] = r[..., None]
return edge_features
class SE3Transformer(nn.Module):
def __init__(self,
num_layers: int,
fiber_in: Fiber,
fiber_hidden: Fiber,
fiber_out: Fiber,
num_heads: int,
channels_div: int,
fiber_edge: Fiber = Fiber({}),
return_type: Optional[int] = None,
pooling: Optional[Literal['avg', 'max']] = None,
norm: bool = True,
use_layer_norm: bool = True,
tensor_cores: bool = False,
low_memory: bool = False,
**kwargs):
"""
:param num_layers: Number of attention layers
:param fiber_in: Input fiber description
:param fiber_hidden: Hidden fiber description
:param fiber_out: Output fiber description
:param fiber_edge: Input edge fiber description
:param num_heads: Number of attention heads
:param channels_div: Channels division before feeding to attention layer
:param return_type: Return only features of this type
:param pooling: 'avg' or 'max' graph pooling before MLP layers
:param norm: Apply a normalization layer after each attention block
:param use_layer_norm: Apply layer normalization between MLP layers
:param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
:param low_memory: If True, will use slower ops that use less memory
"""
super().__init__()
self.num_layers = num_layers
self.fiber_edge = fiber_edge
self.num_heads = num_heads
self.channels_div = channels_div
self.return_type = return_type
self.pooling = pooling
self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
self.tensor_cores = tensor_cores
self.low_memory = low_memory
if low_memory and not tensor_cores:
            logging.warning('Low memory mode has no effect when Tensor Cores are not used')
# Fully fused convolutions when using Tensor Cores (and not low memory mode)
fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL
graph_modules = []
for i in range(num_layers):
graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
fiber_out=fiber_hidden,
fiber_edge=fiber_edge,
num_heads=num_heads,
channels_div=channels_div,
use_layer_norm=use_layer_norm,
max_degree=self.max_degree,
fuse_level=fuse_level))
if norm:
graph_modules.append(NormSE3(fiber_hidden))
fiber_in = fiber_hidden
graph_modules.append(ConvSE3(fiber_in=fiber_in,
fiber_out=fiber_out,
fiber_edge=fiber_edge,
self_interaction=True,
use_layer_norm=use_layer_norm,
max_degree=self.max_degree))
self.graph_modules = Sequential(*graph_modules)
if pooling is not None:
assert return_type is not None, 'return_type must be specified when pooling'
self.pooling_module = GPooling(pool=pooling, feat_type=return_type)
def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
edge_feats: Optional[Dict[str, Tensor]] = None,
basis: Optional[Dict[str, Tensor]] = None):
# Compute bases in case they weren't precomputed as part of the data loading
basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
use_pad_trick=self.tensor_cores and not self.low_memory,
amp=torch.is_autocast_enabled())
# Add fused bases (per output degree, per input degree, and fully fused) to the dict
basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
fully_fused=self.tensor_cores and not self.low_memory)
edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)
node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)
if self.pooling is not None:
return self.pooling_module(node_feats, graph=graph)
if self.return_type is not None:
return node_feats[str(self.return_type)]
return node_feats
@staticmethod
def add_argparse_args(parser):
parser.add_argument('--num_layers', type=int, default=7,
help='Number of stacked Transformer layers')
parser.add_argument('--num_heads', type=int, default=8,
help='Number of heads in self-attention')
parser.add_argument('--channels_div', type=int, default=2,
help='Channels division before feeding to attention layer')
parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
help='Type of graph pooling')
parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
help='Apply a normalization layer after each attention block')
parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
help='Apply layer normalization between MLP layers')
parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
help='If true, will use fused ops that are slower but that use less memory '
'(expect 25 percent less memory). '
'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')
return parser
class SE3TransformerPooled(nn.Module):
def __init__(self,
fiber_in: Fiber,
fiber_out: Fiber,
fiber_edge: Fiber,
num_degrees: int,
num_channels: int,
output_dim: int,
**kwargs):
super().__init__()
kwargs['pooling'] = kwargs['pooling'] or 'max'
self.transformer = SE3Transformer(
fiber_in=fiber_in,
fiber_hidden=Fiber.create(num_degrees, num_channels),
fiber_out=fiber_out,
fiber_edge=fiber_edge,
return_type=0,
**kwargs
)
n_out_features = fiber_out.num_features
self.mlp = nn.Sequential(
nn.Linear(n_out_features, n_out_features),
nn.ReLU(),
nn.Linear(n_out_features, output_dim)
)
def forward(self, graph, node_feats, edge_feats, basis=None):
feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
y = self.mlp(feats).squeeze(-1)
return y
@staticmethod
def add_argparse_args(parent_parser):
parser = parent_parser.add_argument_group("Model architecture")
SE3Transformer.add_argparse_args(parser)
parser.add_argument('--num_degrees',
help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
type=int, default=4)
parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
return parent_parser
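
# End-to-end usage sketch (illustrative; all sizes are made up, and a CUDA build
# of PyTorch is assumed for the nvtx ranges in the basis computation). 'rel_pos'
# must be stored on the edges before calling forward.
if __name__ == '__main__':
    import dgl
    model = SE3Transformer(num_layers=2,
                           fiber_in=Fiber({0: 8}),
                           fiber_hidden=Fiber.create(num_degrees=3, num_channels=16),
                           fiber_out=Fiber({1: 4}),
                           num_heads=4,
                           channels_div=2,
                           return_type=1)
    graph = dgl.rand_graph(20, 60)
    graph.edata['rel_pos'] = torch.randn(60, 3)
    out = model(graph, node_feats={'0': torch.randn(20, 8, 1)})
    print(out.shape)  # torch.Size([20, 4, 3]): type-1 output features per node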

View file

@ -0,0 +1,70 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import argparse
import pathlib
from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.runtime.utils import str2bool
PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')
paths = PARSER.add_argument_group('Paths')
paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'),
help='Directory where the data is located or should be downloaded')
paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'),
help='Directory where the results logs should be saved')
paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json',
help='Name for the resulting DLLogger JSON file')
paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None,
help='File where the checkpoint should be saved')
paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None,
help='File of the checkpoint to be loaded')
optimizer = PARSER.add_argument_group('Optimizer')
optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam')
optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002)
optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None)
optimizer.add_argument('--momentum', type=float, default=0.9)
optimizer.add_argument('--weight_decay', type=float, default=0.1)
PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size')
PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally')
PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers')
PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision')
PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=1,
help='Do an evaluation round every N epochs')
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
help='Minimize stdout output')
PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
help='Benchmark mode')
QM9DataModule.add_argparse_args(PARSER)
SE3TransformerPooled.add_argparse_args(PARSER)

View file

@ -0,0 +1,160 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import logging
import time
from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
import torch
from se3_transformer.runtime.loggers import Logger
from se3_transformer.runtime.metrics import MeanAbsoluteError
class BaseCallback(ABC):
def on_fit_start(self, optimizer, args):
pass
def on_fit_end(self):
pass
def on_epoch_end(self):
pass
def on_batch_start(self):
pass
def on_validation_step(self, input, target, pred):
pass
def on_validation_end(self, epoch=None):
pass
def on_checkpoint_load(self, checkpoint):
pass
def on_checkpoint_save(self, checkpoint):
pass
class LRSchedulerCallback(BaseCallback):
def __init__(self, logger: Optional[Logger] = None):
self.logger = logger
self.scheduler = None
@abstractmethod
def get_scheduler(self, optimizer, args):
pass
def on_fit_start(self, optimizer, args):
self.scheduler = self.get_scheduler(optimizer, args)
def on_checkpoint_load(self, checkpoint):
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
def on_checkpoint_save(self, checkpoint):
checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
def on_epoch_end(self):
if self.logger is not None:
self.logger.log_metrics({'learning rate': self.scheduler.get_last_lr()[0]}, step=self.scheduler.last_epoch)
self.scheduler.step()
class QM9MetricCallback(BaseCallback):
""" Logs the rescaled mean absolute error for QM9 regression tasks """
def __init__(self, logger, targets_std, prefix=''):
self.mae = MeanAbsoluteError()
self.logger = logger
self.targets_std = targets_std
self.prefix = prefix
self.best_mae = float('inf')
def on_validation_step(self, input, target, pred):
self.mae(pred.detach(), target.detach())
def on_validation_end(self, epoch=None):
mae = self.mae.compute() * self.targets_std
logging.info(f'{self.prefix} MAE: {mae}')
self.logger.log_metrics({f'{self.prefix} MAE': mae}, epoch)
self.best_mae = min(self.best_mae, mae)
def on_fit_end(self):
if self.best_mae != float('inf'):
self.logger.log_metrics({f'{self.prefix} best MAE': self.best_mae})
class QM9LRSchedulerCallback(LRSchedulerCallback):
def __init__(self, logger, epochs):
super().__init__(logger)
self.epochs = epochs
def get_scheduler(self, optimizer, args):
min_lr = args.min_learning_rate if args.min_learning_rate else args.learning_rate / 10.0
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, self.epochs, eta_min=min_lr)
class PerformanceCallback(BaseCallback):
def __init__(self, logger, batch_size: int, warmup_epochs: int = 1, mode: str = 'train'):
self.batch_size = batch_size
self.warmup_epochs = warmup_epochs
self.epoch = 0
self.timestamps = []
self.mode = mode
self.logger = logger
def on_batch_start(self):
if self.epoch >= self.warmup_epochs:
self.timestamps.append(time.time() * 1000.0)
def _log_perf(self):
stats = self.process_performance_stats()
for k, v in stats.items():
logging.info(f'performance {k}: {v}')
self.logger.log_metrics(stats)
def on_epoch_end(self):
self.epoch += 1
def on_fit_end(self):
if self.epoch > self.warmup_epochs:
self._log_perf()
self.timestamps = []
def process_performance_stats(self):
timestamps = np.asarray(self.timestamps)
deltas = np.diff(timestamps)
throughput = (self.batch_size / deltas).mean()
stats = {
f"throughput_{self.mode}": throughput,
f"latency_{self.mode}_mean": deltas.mean(),
f"total_time_{self.mode}": timestamps[-1] - timestamps[0],
}
for level in [90, 95, 99]:
stats.update({f"latency_{self.mode}_{level}": np.percentile(deltas, level)})
return stats
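
# Unit sketch (illustrative): timestamps are collected in milliseconds, so the
# latency statistics above are in ms and the throughput is in samples per ms.
if __name__ == '__main__':
    cb = PerformanceCallback(logger=None, batch_size=240, warmup_epochs=0)
    for _ in range(5):
        cb.on_batch_start()
        time.sleep(0.01)
    print(cb.process_performance_stats())  # ~10 ms latencies, ~24 samples/ms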

View file

@ -0,0 +1,325 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import collections
import itertools
import math
import os
import pathlib
import re
import pynvml
class Device:
# assumes nvml returns list of 64 bit ints
_nvml_affinity_elements = math.ceil(os.cpu_count() / 64)
def __init__(self, device_idx):
super().__init__()
self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
def get_name(self):
return pynvml.nvmlDeviceGetName(self.handle)
def get_uuid(self):
return pynvml.nvmlDeviceGetUUID(self.handle)
def get_cpu_affinity(self):
affinity_string = ""
for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements):
# assume nvml returns list of 64 bit ints
affinity_string = "{:064b}".format(j) + affinity_string
affinity_list = [int(x) for x in affinity_string]
affinity_list.reverse() # so core 0 is in 0th element of list
ret = [i for i, e in enumerate(affinity_list) if e != 0]
return ret
def get_thread_siblings_list():
"""
Returns a list of 2-element integer tuples representing pairs of
hyperthreading cores.
"""
path = "/sys/devices/system/cpu/cpu*/topology/thread_siblings_list"
thread_siblings_list = []
pattern = re.compile(r"(\d+)\D(\d+)")
    # path[0] is "/": glob the remaining pattern relative to the filesystem root
    for fname in pathlib.Path(path[0]).glob(path[1:]):
with open(fname) as f:
content = f.read().strip()
res = pattern.findall(content)
if res:
pair = tuple(map(int, res[0]))
thread_siblings_list.append(pair)
return thread_siblings_list
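
# Parsing sketch (illustrative helper, not part of the release): a
# thread_siblings_list entry typically reads "0,64" or "0-1"; the regex above
# captures both integers whatever the separator character is.
def _parse_siblings_example() -> tuple:
    pattern = re.compile(r"(\d+)\D(\d+)")
    return tuple(map(int, pattern.findall("0,64")[0]))  # -> (0, 64)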
def check_socket_affinities(socket_affinities):
# sets of cores should be either identical or disjoint
for i, j in itertools.product(socket_affinities, socket_affinities):
if not set(i) == set(j) and not set(i).isdisjoint(set(j)):
raise RuntimeError(f"Sets of cores should be either identical or disjoint, " f"but got {i} and {j}.")
def get_socket_affinities(nproc_per_node, exclude_unavailable_cores=True):
devices = [Device(i) for i in range(nproc_per_node)]
socket_affinities = [dev.get_cpu_affinity() for dev in devices]
if exclude_unavailable_cores:
available_cores = os.sched_getaffinity(0)
socket_affinities = [list(set(affinity) & available_cores) for affinity in socket_affinities]
check_socket_affinities(socket_affinities)
return socket_affinities
def set_socket_affinity(gpu_id):
"""
The process is assigned with all available logical CPU cores from the CPU
socket connected to the GPU with a given id.
Args:
gpu_id: index of a GPU
"""
dev = Device(gpu_id)
affinity = dev.get_cpu_affinity()
os.sched_setaffinity(0, affinity)
def set_single_affinity(gpu_id):
"""
The process is assigned with the first available logical CPU core from the
list of all CPU cores from the CPU socket connected to the GPU with a given
id.
Args:
gpu_id: index of a GPU
"""
dev = Device(gpu_id)
affinity = dev.get_cpu_affinity()
# exclude unavailable cores
available_cores = os.sched_getaffinity(0)
affinity = list(set(affinity) & available_cores)
os.sched_setaffinity(0, affinity[:1])
def set_single_unique_affinity(gpu_id, nproc_per_node):
"""
The process is assigned with a single unique available physical CPU core
from the list of all CPU cores from the CPU socket connected to the GPU with
a given id.
Args:
gpu_id: index of a GPU
"""
socket_affinities = get_socket_affinities(nproc_per_node)
siblings_list = get_thread_siblings_list()
siblings_dict = dict(siblings_list)
# remove siblings
for idx, socket_affinity in enumerate(socket_affinities):
socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
affinities = []
assigned = []
for socket_affinity in socket_affinities:
for core in socket_affinity:
if core not in assigned:
affinities.append([core])
assigned.append(core)
break
os.sched_setaffinity(0, affinities[gpu_id])
def set_socket_unique_affinity(gpu_id, nproc_per_node, mode, balanced=True):
"""
The process is assigned with an unique subset of available physical CPU
cores from the CPU socket connected to a GPU with a given id.
Assignment automatically includes hyperthreading siblings (if siblings are
available).
Args:
gpu_id: index of a GPU
nproc_per_node: total number of processes per node
        mode: core assignment pattern, either 'interleaved' or 'continuous'
balanced: assign an equal number of physical cores to each process
"""
socket_affinities = get_socket_affinities(nproc_per_node)
siblings_list = get_thread_siblings_list()
siblings_dict = dict(siblings_list)
# remove hyperthreading siblings
for idx, socket_affinity in enumerate(socket_affinities):
socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
socket_affinities_to_device_ids = collections.defaultdict(list)
for idx, socket_affinity in enumerate(socket_affinities):
socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)
    # compute the minimal number of physical cores per GPU across all GPUs
    # and sockets; if balanced == True, this number of cores is assigned to
    # every GPU
min_physical_cores_per_gpu = min(
[len(cores) // len(gpus) for cores, gpus in socket_affinities_to_device_ids.items()]
)
for socket_affinity, device_ids in socket_affinities_to_device_ids.items():
devices_per_group = len(device_ids)
if balanced:
cores_per_device = min_physical_cores_per_gpu
socket_affinity = socket_affinity[: devices_per_group * min_physical_cores_per_gpu]
else:
cores_per_device = len(socket_affinity) // devices_per_group
for group_id, device_id in enumerate(device_ids):
if device_id == gpu_id:
                # In theory there should be no difference in performance
                # between the 'interleaved' and 'continuous' patterns on
                # Intel-based DGX-1, but 'continuous' should be better for
                # DGX A100, because on AMD Rome each group of 4 consecutive
                # cores shares an L3 cache.
                # TODO: the code doesn't attempt to automatically detect the
                # L3 cache layout; the external environment may also already
                # exclude some cores, and this code makes no attempt to
                # detect that and align the mapping to multiples of 4.
if mode == "interleaved":
affinity = list(socket_affinity[group_id::devices_per_group])
elif mode == "continuous":
affinity = list(socket_affinity[group_id * cores_per_device: (group_id + 1) * cores_per_device])
else:
raise RuntimeError("Unknown set_socket_unique_affinity mode")
                # Unconditionally reintroduce hyperthreading siblings. This
                # step may result in different numbers of logical cores
                # assigned to each GPU even if balanced == True: if
                # hyperthreading siblings aren't available for a subset of
                # cores due to external constraints, the siblings are
                # re-added anyway, and in the worst case an unavailable
                # logical core is simply ignored by os.sched_setaffinity().
affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]
os.sched_setaffinity(0, affinity)
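# Worked example (added for clarity, invented core numbers): the two indexing
# patterns above, for 8 physical cores of one socket shared by 2 GPUs.
_cores, _n_gpus, _per_gpu = list(range(8)), 2, 4
assert [_cores[g::_n_gpus] for g in range(_n_gpus)] == [[0, 2, 4, 6], [1, 3, 5, 7]]                       # 'interleaved'
assert [_cores[g * _per_gpu:(g + 1) * _per_gpu] for g in range(_n_gpus)] == [[0, 1, 2, 3], [4, 5, 6, 7]]  # 'continuous'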
def set_affinity(gpu_id, nproc_per_node, mode="socket_unique_continuous", balanced=True):
"""
    The process is assigned a CPU affinity that matches the hardware
    architecture on the given platform. This usually improves and stabilizes
    the performance of deep learning training workloads.
This function assumes that the workload is running in multi-process
single-device mode (there are multiple training processes and each process
is running on a single GPU), which is typical for multi-GPU training
workloads using `torch.nn.parallel.DistributedDataParallel`.
    Available affinity modes:
    * 'socket' - the process is assigned all available logical CPU cores from
    the CPU socket connected to the GPU with the given id.
    * 'single' - the process is assigned the first available logical CPU core
    from the list of all CPU cores on the CPU socket connected to the GPU
    with the given id (multiple GPUs may be assigned the same CPU core).
    * 'single_unique' - the process is assigned a single unique available
    physical CPU core from the list of all CPU cores on the CPU socket
    connected to the GPU with the given id.
    * 'socket_unique_interleaved' - the process is assigned a unique subset
    of available physical CPU cores from the CPU socket connected to the GPU
    with the given id; hyperthreading siblings are included automatically and
    cores are assigned with an interleaved indexing pattern.
    * 'socket_unique_continuous' - (the default) the process is assigned a
    unique subset of available physical CPU cores from the CPU socket
    connected to the GPU with the given id; hyperthreading siblings are
    included automatically and cores are assigned with a continuous indexing
    pattern.
    'socket_unique_continuous' is the recommended mode for deep learning
    training workloads on NVIDIA DGX machines.
Args:
gpu_id: integer index of a GPU
nproc_per_node: number of processes per node
mode: affinity mode
balanced: assign an equal number of physical cores to each process,
affects only 'socket_unique_interleaved' and
'socket_unique_continuous' affinity modes
Returns a set of logical CPU cores on which the process is eligible to run.
Example:
import argparse
import os
import gpu_affinity
import torch
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'--local_rank',
type=int,
default=os.getenv('LOCAL_RANK', 0),
)
args = parser.parse_args()
nproc_per_node = torch.cuda.device_count()
affinity = gpu_affinity.set_affinity(args.local_rank, nproc_per_node)
print(f'{args.local_rank}: core affinity: {affinity}')
if __name__ == "__main__":
main()
Launch the example with:
python -m torch.distributed.launch --nproc_per_node <#GPUs> example.py
    WARNING: On DGX A100, only half of the CPU cores have direct access to
    the GPUs. This function restricts execution to the CPU cores directly
    connected to the GPUs, so on DGX A100 it limits the code to half of the
    CPU cores and half of the CPU memory bandwidth (which may be fine for
    many DL models).
"""
pynvml.nvmlInit()
if mode == "socket":
set_socket_affinity(gpu_id)
elif mode == "single":
set_single_affinity(gpu_id)
elif mode == "single_unique":
set_single_unique_affinity(gpu_id, nproc_per_node)
elif mode == "socket_unique_interleaved":
set_socket_unique_affinity(gpu_id, nproc_per_node, "interleaved", balanced)
elif mode == "socket_unique_continuous":
set_socket_unique_affinity(gpu_id, nproc_per_node, "continuous", balanced)
else:
raise RuntimeError("Unknown affinity mode")
affinity = os.sched_getaffinity(0)
return affinity

View file

@ -0,0 +1,131 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import List
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm
from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import BaseCallback
from se3_transformer.runtime.loggers import DLLogger
from se3_transformer.runtime.utils import to_cuda, get_local_rank
@torch.inference_mode()
def evaluate(model: nn.Module,
dataloader: DataLoader,
callbacks: List[BaseCallback],
args):
model.eval()
    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), unit='batch', desc='Evaluation',
                         leave=False, disable=(args.silent or get_local_rank() != 0)):
        *inputs, target = to_cuda(batch)
        for callback in callbacks:
            callback.on_batch_start()
        with torch.cuda.amp.autocast(enabled=args.amp):
            pred = model(*inputs)
            for callback in callbacks:
                callback.on_validation_step(inputs, target, pred)
if __name__ == '__main__':
from se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
from se3_transformer.runtime.utils import init_distributed, seed_everything
from se3_transformer.model import SE3TransformerPooled, Fiber
from se3_transformer.data_loading import QM9DataModule
import torch.distributed as dist
import logging
import sys
is_distributed = init_distributed()
local_rank = get_local_rank()
args = PARSER.parse_args()
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
logging.info('====== SE(3)-Transformer ======')
logging.info('| Inference on the test set |')
logging.info('===============================')
if not args.benchmark and args.load_ckpt_path is None:
logging.error('No load_ckpt_path provided, you need to provide a saved model to evaluate')
sys.exit(1)
if args.benchmark:
logging.info('Running benchmark mode with one warmup pass')
if args.seed is not None:
seed_everything(args.seed)
major_cc, minor_cc = torch.cuda.get_device_capability()
logger = DLLogger(args.log_dir, filename=args.dllogger_name)
datamodule = QM9DataModule(**vars(args))
model = SE3TransformerPooled(
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
output_dim=1,
tensor_cores=(args.amp and major_cc >= 7) or major_cc >= 8, # use Tensor Cores more effectively
**vars(args)
)
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='test')]
model.to(device=torch.cuda.current_device())
if args.load_ckpt_path is not None:
checkpoint = torch.load(str(args.load_ckpt_path), map_location={'cuda:0': f'cuda:{local_rank}'})
model.load_state_dict(checkpoint['state_dict'])
if is_distributed:
nproc_per_node = torch.cuda.device_count()
affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node)
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
evaluate(model,
test_dataloader,
callbacks,
args)
for callback in callbacks:
callback.on_validation_end()
if args.benchmark:
world_size = dist.get_world_size() if dist.is_initialized() else 1
callbacks = [PerformanceCallback(logger, args.batch_size * world_size, warmup_epochs=1, mode='inference')]
for _ in range(6):
evaluate(model,
test_dataloader,
callbacks,
args)
callbacks[0].on_epoch_end()
callbacks[0].on_fit_end()

View file

@ -0,0 +1,134 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import pathlib
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Any, Callable, Optional
import dllogger
import torch.distributed as dist
import wandb
from dllogger import Verbosity
from se3_transformer.runtime.utils import rank_zero_only
class Logger(ABC):
@rank_zero_only
@abstractmethod
def log_hyperparams(self, params):
pass
@rank_zero_only
@abstractmethod
def log_metrics(self, metrics, step=None):
pass
@staticmethod
def _sanitize_params(params):
def _sanitize(val):
if isinstance(val, Callable):
try:
_val = val()
if isinstance(_val, Callable):
return val.__name__
return _val
except Exception:
return getattr(val, "__name__", None)
elif isinstance(val, pathlib.Path) or isinstance(val, Enum):
return str(val)
return val
return {key: _sanitize(val) for key, val in params.items()}
class LoggerCollection(Logger):
def __init__(self, loggers):
super().__init__()
self.loggers = loggers
    def __getitem__(self, index):
        return self.loggers[index]
@rank_zero_only
def log_metrics(self, metrics, step=None):
for logger in self.loggers:
logger.log_metrics(metrics, step)
@rank_zero_only
def log_hyperparams(self, params):
for logger in self.loggers:
logger.log_hyperparams(params)
class DLLogger(Logger):
def __init__(self, save_dir: pathlib.Path, filename: str):
super().__init__()
if not dist.is_initialized() or dist.get_rank() == 0:
save_dir.mkdir(parents=True, exist_ok=True)
dllogger.init(
backends=[dllogger.JSONStreamBackend(Verbosity.DEFAULT, str(save_dir / filename))])
@rank_zero_only
def log_hyperparams(self, params):
params = self._sanitize_params(params)
dllogger.log(step="PARAMETER", data=params)
@rank_zero_only
def log_metrics(self, metrics, step=None):
if step is None:
step = tuple()
dllogger.log(step=step, data=metrics)
class WandbLogger(Logger):
def __init__(
self,
name: str,
save_dir: pathlib.Path,
id: Optional[str] = None,
project: Optional[str] = None
):
super().__init__()
if not dist.is_initialized() or dist.get_rank() == 0:
save_dir.mkdir(parents=True, exist_ok=True)
self.experiment = wandb.init(name=name,
project=project,
id=id,
dir=str(save_dir),
resume='allow',
anonymous='must')
@rank_zero_only
def log_hyperparams(self, params: Dict[str, Any]) -> None:
params = self._sanitize_params(params)
self.experiment.config.update(params, allow_val_change=True)
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
if step is not None:
self.experiment.log({**metrics, 'epoch': step})
else:
self.experiment.log(metrics)

View file

@ -0,0 +1,83 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from abc import ABC, abstractmethod
import torch
import torch.distributed as dist
from torch import Tensor
class Metric(ABC):
""" Metric class with synchronization capabilities similar to TorchMetrics """
def __init__(self):
self.states = {}
def add_state(self, name: str, default: Tensor):
assert name not in self.states
self.states[name] = default.clone()
setattr(self, name, default)
def synchronize(self):
if dist.is_initialized():
for state in self.states:
dist.all_reduce(getattr(self, state), op=dist.ReduceOp.SUM, group=dist.group.WORLD)
def __call__(self, *args, **kwargs):
self.update(*args, **kwargs)
def reset(self):
for name, default in self.states.items():
setattr(self, name, default.clone())
def compute(self):
self.synchronize()
value = self._compute().item()
self.reset()
return value
@abstractmethod
def _compute(self):
pass
@abstractmethod
def update(self, preds: Tensor, targets: Tensor):
pass
class MeanAbsoluteError(Metric):
def __init__(self):
super().__init__()
self.add_state('error', torch.tensor(0, dtype=torch.float32, device='cuda'))
self.add_state('total', torch.tensor(0, dtype=torch.int32, device='cuda'))
def update(self, preds: Tensor, targets: Tensor):
preds = preds.detach()
n = preds.shape[0]
error = torch.abs(preds.view(n, -1) - targets.view(n, -1)).sum()
self.total += n
self.error += error
def _compute(self):
return self.error / self.total
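# Minimal usage sketch (added for clarity; hypothetical values). The metric
# states above live on the CUDA device, hence the guard; the function is
# illustrative only and never called by the module.
def _mae_example():
    if torch.cuda.is_available():
        mae = MeanAbsoluteError()
        mae(torch.tensor([[1.0], [2.0]], device='cuda'),   # preds
            torch.tensor([[0.0], [4.0]], device='cuda'))   # targets
        assert mae.compute() == 1.5                        # (|1-0| + |2-4|) / 2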

View file

@ -0,0 +1,238 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import logging
import pathlib
from typing import List
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from apex.optimizers import FusedAdam, FusedLAMB
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
PerformanceCallback
from se3_transformer.runtime.inference import evaluate
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
using_tensor_cores, increase_l2_fetch_granularity
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
""" Saves model, optimizer and epoch states to path (only once per node) """
if get_local_rank() == 0:
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
checkpoint = {
'state_dict': state_dict,
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch
}
for callback in callbacks:
callback.on_checkpoint_save(checkpoint)
torch.save(checkpoint, str(path))
logging.info(f'Saved checkpoint to {str(path)}')
def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
""" Loads model, optimizer and epoch states from path """
checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
if isinstance(model, DistributedDataParallel):
model.module.load_state_dict(checkpoint['state_dict'])
else:
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
for callback in callbacks:
callback.on_checkpoint_load(checkpoint)
logging.info(f'Loaded checkpoint from {str(path)}')
return checkpoint['epoch']
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
losses = []
for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
*inputs, target = to_cuda(batch)
for callback in callbacks:
callback.on_batch_start()
with torch.cuda.amp.autocast(enabled=args.amp):
pred = model(*inputs)
loss = loss_fn(pred, target) / args.accumulate_grad_batches
grad_scaler.scale(loss).backward()
# gradient accumulation
if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
if args.gradient_clip:
grad_scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
grad_scaler.step(optimizer)
grad_scaler.update()
optimizer.zero_grad()
losses.append(loss.item())
return np.mean(losses)
def train(model: nn.Module,
loss_fn: _Loss,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
callbacks: List[BaseCallback],
logger: Logger,
args):
device = torch.cuda.current_device()
model.to(device=device)
local_rank = get_local_rank()
world_size = dist.get_world_size() if dist.is_initialized() else 1
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
model.train()
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
if args.optimizer == 'adam':
optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
weight_decay=args.weight_decay)
elif args.optimizer == 'lamb':
optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
weight_decay=args.weight_decay)
else:
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
weight_decay=args.weight_decay)
epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0
for callback in callbacks:
callback.on_fit_start(optimizer, args)
for epoch_idx in range(epoch_start, args.epochs):
if isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch_idx)
        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
if dist.is_initialized():
loss = torch.tensor(loss, dtype=torch.float, device=device)
torch.distributed.all_reduce(loss)
loss = (loss / world_size).item()
logging.info(f'Train loss: {loss}')
logger.log_metrics({'train loss': loss}, epoch_idx)
for callback in callbacks:
callback.on_epoch_end()
if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
and (epoch_idx + 1) % args.ckpt_interval == 0:
save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)
if not args.benchmark and args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0:
evaluate(model, val_dataloader, callbacks, args)
model.train()
for callback in callbacks:
callback.on_validation_end(epoch_idx)
if args.save_ckpt_path is not None and not args.benchmark:
save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)
for callback in callbacks:
callback.on_fit_end()
def print_parameters_count(model):
num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
logging.info(f'Number of trainable parameters: {num_params_trainable}')
if __name__ == '__main__':
is_distributed = init_distributed()
local_rank = get_local_rank()
args = PARSER.parse_args()
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
logging.info('====== SE(3)-Transformer ======')
logging.info('| Training procedure |')
logging.info('===============================')
if args.seed is not None:
logging.info(f'Using seed {args.seed}')
seed_everything(args.seed)
logger = LoggerCollection([
DLLogger(save_dir=args.log_dir, filename=args.dllogger_name),
WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
])
datamodule = QM9DataModule(**vars(args))
model = SE3TransformerPooled(
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
output_dim=1,
tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively
**vars(args)
)
loss_fn = nn.L1Loss()
if args.benchmark:
logging.info('Running benchmark mode')
world_size = dist.get_world_size() if dist.is_initialized() else 1
callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
else:
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
QM9LRSchedulerCallback(logger, epochs=args.epochs)]
if is_distributed:
gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())
print_parameters_count(model)
logger.log_hyperparams(vars(args))
increase_l2_fetch_granularity()
train(model,
loss_fn,
datamodule.train_dataloader(),
datamodule.val_dataloader(),
callbacks,
logger,
args)
logging.info('Training finished successfully')

View file

@ -0,0 +1,130 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import argparse
import ctypes
import logging
import os
import random
from functools import wraps
from typing import Union, List, Dict
import numpy as np
import torch
import torch.distributed as dist
from torch import Tensor
def aggregate_residual(feats1, feats2, method: str):
""" Add or concatenate two fiber features together. If degrees don't match, will use the ones of feats2. """
if method in ['add', 'sum']:
return {k: (v + feats1[k]) if k in feats1 else v for k, v in feats2.items()}
elif method in ['cat', 'concat']:
return {k: torch.cat([v, feats1[k]], dim=1) if k in feats1 else v for k, v in feats2.items()}
else:
raise ValueError('Method must be add/sum or cat/concat')
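# Worked example (added for clarity, invented shapes): degree '0' is summed,
# degree '1' (absent from feats1) is passed through unchanged.
_f1 = {'0': torch.ones(2, 4, 1)}
_f2 = {'0': torch.ones(2, 4, 1), '1': torch.ones(2, 4, 3)}
_out = aggregate_residual(_f1, _f2, 'add')
assert (_out['0'] == 2).all() and _out['1'].shape == (2, 4, 3)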
def degree_to_dim(degree: int) -> int:
return 2 * degree + 1
def unfuse_features(features: Tensor, degrees: List[int]) -> Dict[str, Tensor]:
return dict(zip(map(str, degrees), features.split([degree_to_dim(deg) for deg in degrees], dim=-1)))
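# Worked example (added for clarity, invented shapes): with degrees [0, 1] the
# last axis splits into degree_to_dim(0) = 1 and degree_to_dim(1) = 3 channels.
_fused = torch.randn(10, 32, 4)
_unfused = unfuse_features(_fused, [0, 1])
assert _unfused['0'].shape == (10, 32, 1) and _unfused['1'].shape == (10, 32, 3)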
def str2bool(v: Union[bool, str]) -> bool:
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def to_cuda(x):
""" Try to convert a Tensor, a collection of Tensors or a DGLGraph to CUDA """
if isinstance(x, Tensor):
return x.cuda(non_blocking=True)
    elif isinstance(x, tuple):
        # return a tuple (not a generator) so the result can be reused
        return tuple(to_cuda(v) for v in x)
elif isinstance(x, list):
return [to_cuda(v) for v in x]
elif isinstance(x, dict):
return {k: to_cuda(v) for k, v in x.items()}
else:
# DGLGraph or other objects
return x.to(device=torch.cuda.current_device())
def get_local_rank() -> int:
return int(os.environ.get('LOCAL_RANK', 0))
def init_distributed() -> bool:
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    distributed = world_size > 1
    if distributed:
        backend = 'nccl' if torch.cuda.is_available() else 'gloo'
        dist.init_process_group(backend=backend, init_method='env://')
        if backend == 'nccl':
            torch.cuda.set_device(get_local_rank())
        else:
            logging.warning('Running on CPU only!')
        # only assert when initialization was actually attempted
        assert torch.distributed.is_initialized()
    return distributed
def increase_l2_fetch_granularity():
# maximum fetch granularity of L2: 128 bytes
_libcudart = ctypes.CDLL('libcudart.so')
# set device limit on the current device
# cudaLimitMaxL2FetchGranularity = 0x05
pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
_libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
_libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
assert pValue.contents.value == 128
def seed_everything(seed):
seed = int(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def rank_zero_only(fn):
@wraps(fn)
def wrapped_fn(*args, **kwargs):
if not dist.is_initialized() or dist.get_rank() == 0:
return fn(*args, **kwargs)
return wrapped_fn
def using_tensor_cores(amp: bool) -> bool:
major_cc, minor_cc = torch.cuda.get_device_capability()
return (amp and major_cc >= 7) or major_cc >= 8

View file

@ -0,0 +1,11 @@
from setuptools import setup, find_packages
setup(
name='se3-transformer',
packages=find_packages(),
include_package_data=True,
version='1.0.0',
description='PyTorch + DGL implementation of SE(3)-Transformers',
author='Alexandre Milesi',
author_email='alexandrem@nvidia.com',
)

View file

@ -0,0 +1,102 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import torch
from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber
from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot
# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) )
TOL = 1e-3
CHANNELS, NODES = 32, 512
def _get_outputs(model, R):
feats0 = torch.randn(NODES, CHANNELS, 1)
feats1 = torch.randn(NODES, CHANNELS, 3)
coords = torch.randn(NODES, 3)
graph = get_random_graph(NODES)
if torch.cuda.is_available():
feats0 = feats0.cuda()
feats1 = feats1.cuda()
R = R.cuda()
coords = coords.cuda()
graph = graph.to('cuda')
model.cuda()
graph1 = assign_relative_pos(graph, coords)
out1 = model(graph1, {'0': feats0, '1': feats1}, {})
graph2 = assign_relative_pos(graph, coords @ R)
out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {})
return out1, out2
def _get_model(**kwargs):
return SE3Transformer(
num_layers=4,
fiber_in=Fiber.create(2, CHANNELS),
fiber_hidden=Fiber.create(3, CHANNELS),
fiber_out=Fiber.create(2, CHANNELS),
fiber_edge=Fiber({}),
num_heads=8,
channels_div=2,
**kwargs
)
def test_equivariance():
model = _get_model()
R = rot(*torch.rand(3))
if torch.cuda.is_available():
R = R.cuda()
out1, out2 = _get_outputs(model, R)
assert torch.allclose(out2['0'], out1['0'], atol=TOL), \
f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}'
assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \
f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}'
def test_equivariance_pooled():
model = _get_model(pooling='avg', return_type=1)
R = rot(*torch.rand(3))
if torch.cuda.is_available():
R = R.cuda()
out1, out2 = _get_outputs(model, R)
assert torch.allclose(out2, (out1 @ R), atol=TOL), \
f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}'
def test_invariance_pooled():
model = _get_model(pooling='avg', return_type=0)
R = rot(*torch.rand(3))
if torch.cuda.is_available():
R = R.cuda()
out1, out2 = _get_outputs(model, R)
assert torch.allclose(out2, out1, atol=TOL), \
f'type-0 features should be invariant {get_max_diff(out1, out2)}'

View file

@ -0,0 +1,60 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import dgl
import torch
def get_random_graph(N, num_edges_factor=18):
graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor))
return graph
def assign_relative_pos(graph, coords):
src, dst = graph.edges()
graph.edata['rel_pos'] = coords[src] - coords[dst]
return graph
def get_max_diff(a, b):
return (a - b).abs().max().item()
def rot_z(gamma):
return torch.tensor([
[torch.cos(gamma), -torch.sin(gamma), 0],
[torch.sin(gamma), torch.cos(gamma), 0],
[0, 0, 1]
], dtype=gamma.dtype)
def rot_y(beta):
return torch.tensor([
[torch.cos(beta), 0, torch.sin(beta)],
[0, 1, 0],
[-torch.sin(beta), 0, torch.cos(beta)]
], dtype=beta.dtype)
def rot(alpha, beta, gamma):
return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
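# Sanity-check sketch (added for clarity): the composed z-y-z Euler rotation
# should be orthogonal with determinant 1.
_R = rot(*torch.rand(3))
assert torch.allclose(_R @ _R.T, torch.eye(3), atol=1e-5)
assert torch.allclose(torch.det(_R), torch.tensor(1.0), atol=1e-5)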