[SE3Transformer/DGLPyT] Better low memory mode

This commit is contained in:
Alexandre Milesi 2021-11-01 17:49:17 +01:00
parent dcd3bbac09
commit 22d6621dcd
19 changed files with 1610 additions and 1574 deletions

View File

@ -42,7 +42,6 @@ RUN make -j8
FROM ${FROM_IMAGE_NAME}
RUN rm -rf /workspace/*
WORKDIR /workspace/se3-transformer
# copy built DGL and install it
@ -55,3 +54,5 @@ ADD . .
ENV DGLBACKEND=pytorch
ENV OMP_NUM_THREADS=1

View File

@ -126,7 +126,13 @@ The following performance optimizations were implemented in this model:
- The layout (order of dimensions) of the bases tensors is optimized to avoid copies to contiguous memory in the downstream TFN layers
- When Tensor Cores are available, and the output feature dimension of computed bases is odd, then it is padded with zeros to make more effective use of Tensor Cores (AMP and TF32 precisions)
- Multiple levels of fusion for TFN convolutions (and radial profiles) are provided and automatically used when conditions are met
- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`)
- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`). Overview of memory savings over the official implementation (batch size 100), depending on the precision and the low memory mode:
| | FP32 | AMP
|---|-----------------------|--------------------------
|`--low_memory false` (default) | 4.7x | 7.1x
|`--low_memory true` | 29.4x | 43.6x
**Self-attention optimizations**
@ -358,7 +364,7 @@ The complete list of the available parameters for the `training.py` script conta
- `--pooling`: Type of graph pooling (default: `max`)
- `--norm`: Apply a normalization layer after each attention block (default: `false`)
- `--use_layer_norm`: Apply layer normalization between MLP layers (default: `false`)
- `--low_memory`: If true, will use fused ops that are slower but use less memory (expect 25 percent less memory). Only has an effect if AMP is enabled on NVIDIA Volta GPUs or if running on Ampere GPUs (default: `false`)
- `--low_memory`: If true, will use ops that are slower but use less memory (default: `false`)
- `--num_degrees`: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: `4`)
- `--num_channels`: Number of channels for the hidden features (default: `32`)
@ -407,7 +413,8 @@ The training script is `se3_transformer/runtime/training.py`, to be run as a mod
By default, the resulting logs are stored in `/results/`. This can be changed with `--log_dir`.
You can connect your existing Weights & Biases account by setting the `WANDB_API_KEY` environment variable.
You can connect your existing Weights & Biases account by setting the WANDB_API_KEY environment variable, and enabling the `--wandb` flag.
If no API key is set, `--wandb` will log the run anonymously to Weights & Biases.
**Checkpoints**
@ -573,6 +580,11 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
### Changelog
November 2021:
- Improved low memory mode to give further 6x memory savings
- Disabled W&B logging by default
- Fixed persistent workers when using one data loading process
October 2021:
- Updated README performance tables
- Fixed shape mismatch when using partially fused TFNs per output degree

View File

@ -46,7 +46,8 @@ class DataModule(ABC):
if dist.is_initialized():
dist.barrier(device_ids=[get_local_rank()])
self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs}
self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': dataloader_kwargs.get('num_workers', 0) > 0,
**dataloader_kwargs}
self.ds_train, self.ds_val, self.ds_test = None, None, None
def prepare_data(self):

View File

@ -1,2 +1,2 @@
from .transformer import SE3Transformer, SE3TransformerPooled
from .fiber import Fiber
from .transformer import SE3Transformer, SE3TransformerPooled
from .fiber import Fiber

View File

@ -1,144 +1,144 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from collections import namedtuple
from itertools import product
from typing import Dict
import torch
from torch import Tensor
from se3_transformer.runtime.utils import degree_to_dim
FiberEl = namedtuple('FiberEl', ['degree', 'channels'])
class Fiber(dict):
"""
Describes the structure of some set of features.
Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1.
Type-0 features: invariant scalars
Type-1 features: equivariant 3D vectors
Type-2 features: equivariant symmetric traceless matrices
...
As inputs to a SE3 layer, there can be many features of the same types, and many features of different types.
The 'multiplicity' or 'number of channels' is the number of features of a given type.
This class puts together all the degrees and their multiplicities in order to describe
the inputs, outputs or hidden features of SE3 layers.
"""
def __init__(self, structure):
if isinstance(structure, dict):
structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])]
elif not isinstance(structure[0], FiberEl):
structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1])))
self.structure = structure
super().__init__({d: m for d, m in self.structure})
@property
def degrees(self):
return sorted([t.degree for t in self.structure])
@property
def channels(self):
return [self[d] for d in self.degrees]
@property
def num_features(self):
""" Size of the resulting tensor if all features were concatenated together """
return sum(t.channels * degree_to_dim(t.degree) for t in self.structure)
@staticmethod
def create(num_degrees: int, num_channels: int):
""" Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """
return Fiber([(degree, num_channels) for degree in range(num_degrees)])
@staticmethod
def from_features(feats: Dict[str, Tensor]):
""" Infer the Fiber structure from a feature dict """
structure = {}
for k, v in feats.items():
degree = int(k)
assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
assert v.shape[-1] == degree_to_dim(degree)
structure[degree] = v.shape[-2]
return Fiber(structure)
def __getitem__(self, degree: int):
""" fiber[degree] returns the multiplicity for this degree """
return dict(self.structure).get(degree, 0)
def __iter__(self):
""" Iterate over namedtuples (degree, channels) """
return iter(self.structure)
def __mul__(self, other):
"""
If other in an int, multiplies all the multiplicities by other.
If other is a fiber, returns the cartesian product.
"""
if isinstance(other, Fiber):
return product(self.structure, other.structure)
elif isinstance(other, int):
return Fiber({t.degree: t.channels * other for t in self.structure})
def __add__(self, other):
"""
If other in an int, add other to all the multiplicities.
If other is a fiber, add the multiplicities of the fibers together.
"""
if isinstance(other, Fiber):
return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure})
elif isinstance(other, int):
return Fiber({t.degree: t.channels + other for t in self.structure})
def __repr__(self):
return str(self.structure)
@staticmethod
def combine_max(f1, f2):
""" Combine two fiber by taking the maximum multiplicity for each degree in both fibers """
new_dict = dict(f1.structure)
for k, m in f2.structure:
new_dict[k] = max(new_dict.get(k, 0), m)
return Fiber(list(new_dict.items()))
@staticmethod
def combine_selectively(f1, f2):
""" Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """
# only use orders which occur in fiber f1
new_dict = dict(f1.structure)
for k in f1.degrees:
if k in f2.degrees:
new_dict[k] += f2[k]
return Fiber(list(new_dict.items()))
def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int):
# dict(N, num_channels, 2d+1) -> (N, num_heads, -1)
fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in
self.degrees]
fibers = torch.cat(fibers, -1)
return fibers
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from collections import namedtuple
from itertools import product
from typing import Dict
import torch
from torch import Tensor
from se3_transformer.runtime.utils import degree_to_dim
FiberEl = namedtuple('FiberEl', ['degree', 'channels'])
class Fiber(dict):
"""
Describes the structure of some set of features.
Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1.
Type-0 features: invariant scalars
Type-1 features: equivariant 3D vectors
Type-2 features: equivariant symmetric traceless matrices
...
As inputs to a SE3 layer, there can be many features of the same types, and many features of different types.
The 'multiplicity' or 'number of channels' is the number of features of a given type.
This class puts together all the degrees and their multiplicities in order to describe
the inputs, outputs or hidden features of SE3 layers.
"""
def __init__(self, structure):
if isinstance(structure, dict):
structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])]
elif not isinstance(structure[0], FiberEl):
structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1])))
self.structure = structure
super().__init__({d: m for d, m in self.structure})
@property
def degrees(self):
return sorted([t.degree for t in self.structure])
@property
def channels(self):
return [self[d] for d in self.degrees]
@property
def num_features(self):
""" Size of the resulting tensor if all features were concatenated together """
return sum(t.channels * degree_to_dim(t.degree) for t in self.structure)
@staticmethod
def create(num_degrees: int, num_channels: int):
""" Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """
return Fiber([(degree, num_channels) for degree in range(num_degrees)])
@staticmethod
def from_features(feats: Dict[str, Tensor]):
""" Infer the Fiber structure from a feature dict """
structure = {}
for k, v in feats.items():
degree = int(k)
assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
assert v.shape[-1] == degree_to_dim(degree)
structure[degree] = v.shape[-2]
return Fiber(structure)
def __getitem__(self, degree: int):
""" fiber[degree] returns the multiplicity for this degree """
return dict(self.structure).get(degree, 0)
def __iter__(self):
""" Iterate over namedtuples (degree, channels) """
return iter(self.structure)
def __mul__(self, other):
"""
If other in an int, multiplies all the multiplicities by other.
If other is a fiber, returns the cartesian product.
"""
if isinstance(other, Fiber):
return product(self.structure, other.structure)
elif isinstance(other, int):
return Fiber({t.degree: t.channels * other for t in self.structure})
def __add__(self, other):
"""
If other in an int, add other to all the multiplicities.
If other is a fiber, add the multiplicities of the fibers together.
"""
if isinstance(other, Fiber):
return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure})
elif isinstance(other, int):
return Fiber({t.degree: t.channels + other for t in self.structure})
def __repr__(self):
return str(self.structure)
@staticmethod
def combine_max(f1, f2):
""" Combine two fiber by taking the maximum multiplicity for each degree in both fibers """
new_dict = dict(f1.structure)
for k, m in f2.structure:
new_dict[k] = max(new_dict.get(k, 0), m)
return Fiber(list(new_dict.items()))
@staticmethod
def combine_selectively(f1, f2):
""" Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """
# only use orders which occur in fiber f1
new_dict = dict(f1.structure)
for k in f1.degrees:
if k in f2.degrees:
new_dict[k] += f2[k]
return Fiber(list(new_dict.items()))
def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int):
# dict(N, num_channels, 2d+1) -> (N, num_heads, -1)
fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in
self.degrees]
fibers = torch.cat(fibers, -1)
return fibers

View File

@ -1,5 +1,5 @@
from .linear import LinearSE3
from .norm import NormSE3
from .pooling import GPooling
from .convolution import ConvSE3
from .linear import LinearSE3
from .norm import NormSE3
from .pooling import GPooling
from .convolution import ConvSE3
from .attention import AttentionBlockSE3

View File

@ -1,180 +1,181 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import dgl
import numpy as np
import torch
import torch.nn as nn
from dgl import DGLGraph
from dgl.ops import edge_softmax
from torch import Tensor
from typing import Dict, Optional, Union
from se3_transformer.model.fiber import Fiber
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
from torch.cuda.nvtx import range as nvtx_range
class AttentionSE3(nn.Module):
""" Multi-headed sparse graph self-attention (SE(3)-equivariant) """
def __init__(
self,
num_heads: int,
key_fiber: Fiber,
value_fiber: Fiber
):
"""
:param num_heads: Number of attention heads
:param key_fiber: Fiber for the keys (and also for the queries)
:param value_fiber: Fiber for the values
"""
super().__init__()
self.num_heads = num_heads
self.key_fiber = key_fiber
self.value_fiber = value_fiber
def forward(
self,
value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
query: Dict[str, Tensor], # node features
graph: DGLGraph
):
with nvtx_range('AttentionSE3'):
with nvtx_range('reshape keys and queries'):
if isinstance(key, Tensor):
# case where features of all types are fused
key = key.reshape(key.shape[0], self.num_heads, -1)
# need to reshape queries that way to keep the same layout as keys
out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
else:
# features are not fused, need to fuse and reshape them
key = self.key_fiber.to_attention_heads(key, self.num_heads)
query = self.key_fiber.to_attention_heads(query, self.num_heads)
with nvtx_range('attention dot product + softmax'):
# Compute attention weights (softmax of inner product between key and query)
edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
edge_weights = edge_weights / np.sqrt(self.key_fiber.num_features)
edge_weights = edge_softmax(graph, edge_weights)
edge_weights = edge_weights[..., None, None]
with nvtx_range('weighted sum'):
if isinstance(value, Tensor):
# features of all types are fused
v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
weights = edge_weights * v
feat_out = dgl.ops.copy_e_sum(graph, weights)
feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads
out = unfuse_features(feat_out, self.value_fiber.degrees)
else:
out = {}
for degree, channels in self.value_fiber:
v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
degree_to_dim(degree))
weights = edge_weights * v
res = dgl.ops.copy_e_sum(graph, weights)
out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads
return out
class AttentionBlockSE3(nn.Module):
""" Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """
def __init__(
self,
fiber_in: Fiber,
fiber_out: Fiber,
fiber_edge: Optional[Fiber] = None,
num_heads: int = 4,
channels_div: int = 2,
use_layer_norm: bool = False,
max_degree: bool = 4,
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
**kwargs
):
"""
:param fiber_in: Fiber describing the input features
:param fiber_out: Fiber describing the output features
:param fiber_edge: Fiber describing the edge features (node distances excluded)
:param num_heads: Number of attention heads
:param channels_div: Divide the channels by this integer for computing values
:param use_layer_norm: Apply layer normalization between MLP layers
:param max_degree: Maximum degree used in the bases computation
:param fuse_level: Maximum fuse level to use in TFN convolutions
"""
super().__init__()
if fiber_edge is None:
fiber_edge = Fiber({})
self.fiber_in = fiber_in
# value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out])
# key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber
# (queries are merely projected, hence degrees have to match input)
key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])
self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
allow_fused_output=True)
self.to_query = LinearSE3(fiber_in, key_query_fiber)
self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
self.project = LinearSE3(value_fiber + fiber_in, fiber_out)
def forward(
self,
node_features: Dict[str, Tensor],
edge_features: Dict[str, Tensor],
graph: DGLGraph,
basis: Dict[str, Tensor]
):
with nvtx_range('AttentionBlockSE3'):
with nvtx_range('keys / values'):
fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
key, value = self._get_key_value_from_fused(fused_key_value)
with nvtx_range('queries'):
query = self.to_query(node_features)
z = self.attention(value, key, query, graph)
z_concat = aggregate_residual(node_features, z, 'cat')
return self.project(z_concat)
def _get_key_value_from_fused(self, fused_key_value):
# Extract keys and queries features from fused features
if isinstance(fused_key_value, Tensor):
# Previous layer was a fully fused convolution
value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
else:
key, value = {}, {}
for degree, feat in fused_key_value.items():
if int(degree) in self.fiber_in.degrees:
value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
else:
value[degree] = feat
return key, value
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import dgl
import numpy as np
import torch
import torch.nn as nn
from dgl import DGLGraph
from dgl.ops import edge_softmax
from torch import Tensor
from typing import Dict, Optional, Union
from se3_transformer.model.fiber import Fiber
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
from torch.cuda.nvtx import range as nvtx_range
class AttentionSE3(nn.Module):
""" Multi-headed sparse graph self-attention (SE(3)-equivariant) """
def __init__(
self,
num_heads: int,
key_fiber: Fiber,
value_fiber: Fiber
):
"""
:param num_heads: Number of attention heads
:param key_fiber: Fiber for the keys (and also for the queries)
:param value_fiber: Fiber for the values
"""
super().__init__()
self.num_heads = num_heads
self.key_fiber = key_fiber
self.value_fiber = value_fiber
def forward(
self,
value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
query: Dict[str, Tensor], # node features
graph: DGLGraph
):
with nvtx_range('AttentionSE3'):
with nvtx_range('reshape keys and queries'):
if isinstance(key, Tensor):
# case where features of all types are fused
key = key.reshape(key.shape[0], self.num_heads, -1)
# need to reshape queries that way to keep the same layout as keys
out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
else:
# features are not fused, need to fuse and reshape them
key = self.key_fiber.to_attention_heads(key, self.num_heads)
query = self.key_fiber.to_attention_heads(query, self.num_heads)
with nvtx_range('attention dot product + softmax'):
# Compute attention weights (softmax of inner product between key and query)
edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
edge_weights = edge_weights / np.sqrt(self.key_fiber.num_features)
edge_weights = edge_softmax(graph, edge_weights)
edge_weights = edge_weights[..., None, None]
with nvtx_range('weighted sum'):
if isinstance(value, Tensor):
# features of all types are fused
v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
weights = edge_weights * v
feat_out = dgl.ops.copy_e_sum(graph, weights)
feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads
out = unfuse_features(feat_out, self.value_fiber.degrees)
else:
out = {}
for degree, channels in self.value_fiber:
v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
degree_to_dim(degree))
weights = edge_weights * v
res = dgl.ops.copy_e_sum(graph, weights)
out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads
return out
class AttentionBlockSE3(nn.Module):
""" Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """
def __init__(
self,
fiber_in: Fiber,
fiber_out: Fiber,
fiber_edge: Optional[Fiber] = None,
num_heads: int = 4,
channels_div: int = 2,
use_layer_norm: bool = False,
max_degree: bool = 4,
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
low_memory: bool = False,
**kwargs
):
"""
:param fiber_in: Fiber describing the input features
:param fiber_out: Fiber describing the output features
:param fiber_edge: Fiber describing the edge features (node distances excluded)
:param num_heads: Number of attention heads
:param channels_div: Divide the channels by this integer for computing values
:param use_layer_norm: Apply layer normalization between MLP layers
:param max_degree: Maximum degree used in the bases computation
:param fuse_level: Maximum fuse level to use in TFN convolutions
"""
super().__init__()
if fiber_edge is None:
fiber_edge = Fiber({})
self.fiber_in = fiber_in
# value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out])
# key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber
# (queries are merely projected, hence degrees have to match input)
key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])
self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
allow_fused_output=True, low_memory=low_memory)
self.to_query = LinearSE3(fiber_in, key_query_fiber)
self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
self.project = LinearSE3(value_fiber + fiber_in, fiber_out)
def forward(
self,
node_features: Dict[str, Tensor],
edge_features: Dict[str, Tensor],
graph: DGLGraph,
basis: Dict[str, Tensor]
):
with nvtx_range('AttentionBlockSE3'):
with nvtx_range('keys / values'):
fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
key, value = self._get_key_value_from_fused(fused_key_value)
with nvtx_range('queries'):
query = self.to_query(node_features)
z = self.attention(value, key, query, graph)
z_concat = aggregate_residual(node_features, z, 'cat')
return self.project(z_concat)
def _get_key_value_from_fused(self, fused_key_value):
# Extract keys and queries features from fused features
if isinstance(fused_key_value, Tensor):
# Previous layer was a fully fused convolution
value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
else:
key, value = {}, {}
for degree, feat in fused_key_value.items():
if int(degree) in self.fiber_in.degrees:
value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
else:
value[degree] = feat
return key, value

View File

@ -1,345 +1,354 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from enum import Enum
from itertools import product
from typing import Dict
import dgl
import numpy as np
import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime.utils import degree_to_dim, unfuse_features
class ConvSE3FuseLevel(Enum):
"""
Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met.
If a desired level L is picked and the level L cannot be applied to a level, other fused ops < L are considered.
A higher level means faster training, but also more memory usage.
If you are tight on memory and want to feed large inputs to the network, choose a low value.
If you want to train fast, choose a high value.
Recommended value is FULL with AMP.
Fully fused TFN convolutions requirements:
- all input channels are the same
- all output channels are the same
- input degrees span the range [0, ..., max_degree]
- output degrees span the range [0, ..., max_degree]
Partially fused TFN convolutions requirements:
* For fusing by output degree:
- all input channels are the same
- input degrees span the range [0, ..., max_degree]
* For fusing by input degree:
- all output channels are the same
- output degrees span the range [0, ..., max_degree]
Original TFN pairwise convolutions: no requirements
"""
FULL = 2
PARTIAL = 1
NONE = 0
class RadialProfile(nn.Module):
"""
Radial profile function.
Outputs weights used to weigh basis matrices in order to get convolution kernels.
In TFN notation: $R^{l,k}$
In SE(3)-Transformer notation: $\phi^{l,k}$
Note:
In the original papers, this function only depends on relative node distances ||x||.
Here, we allow this function to also take as input additional invariant edge features.
This does not break equivariance and adds expressive power to the model.
Diagram:
invariant edge features (node distances included) > MLP layer (shared across edges) > radial weights
"""
def __init__(
self,
num_freq: int,
channels_in: int,
channels_out: int,
edge_dim: int = 1,
mid_dim: int = 32,
use_layer_norm: bool = False
):
"""
:param num_freq: Number of frequencies
:param channels_in: Number of input channels
:param channels_out: Number of output channels
:param edge_dim: Number of invariant edge features (input to the radial function)
:param mid_dim: Size of the hidden MLP layers
:param use_layer_norm: Apply layer normalization between MLP layers
"""
super().__init__()
modules = [
nn.Linear(edge_dim, mid_dim),
nn.LayerNorm(mid_dim) if use_layer_norm else None,
nn.ReLU(),
nn.Linear(mid_dim, mid_dim),
nn.LayerNorm(mid_dim) if use_layer_norm else None,
nn.ReLU(),
nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
]
self.net = nn.Sequential(*[m for m in modules if m is not None])
def forward(self, features: Tensor) -> Tensor:
return self.net(features)
class VersatileConvSE3(nn.Module):
"""
Building block for TFN convolutions.
This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions.
"""
def __init__(self,
freq_sum: int,
channels_in: int,
channels_out: int,
edge_dim: int,
use_layer_norm: bool,
fuse_level: ConvSE3FuseLevel):
super().__init__()
self.freq_sum = freq_sum
self.channels_out = channels_out
self.channels_in = channels_in
self.fuse_level = fuse_level
self.radial_func = RadialProfile(num_freq=freq_sum,
channels_in=channels_in,
channels_out=channels_out,
edge_dim=edge_dim,
use_layer_norm=use_layer_norm)
def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
with nvtx_range(f'VersatileConvSE3'):
num_edges = features.shape[0]
in_dim = features.shape[2]
with nvtx_range(f'RadialProfile'):
radial_weights = self.radial_func(invariant_edge_feats) \
.view(-1, self.channels_out, self.channels_in * self.freq_sum)
if basis is not None:
# This block performs the einsum n i l, n o i f, n l f k -> n o k
basis_view = basis.view(num_edges, in_dim, -1)
tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
return radial_weights @ tmp
else:
# k = l = 0 non-fused case
return radial_weights @ features
class ConvSE3(nn.Module):
"""
SE(3)-equivariant graph convolution (Tensor Field Network convolution).
This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance.
Features of different degrees interact together to produce output features.
Note 1:
The option is given to not pool the output. This means that the convolution sum over neighbors will not be
done, and the returned features will be edge features instead of node features.
Note 2:
Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0.
Input edge features are concatenated with input source node features before the kernel is applied.
"""
def __init__(
self,
fiber_in: Fiber,
fiber_out: Fiber,
fiber_edge: Fiber,
pool: bool = True,
use_layer_norm: bool = False,
self_interaction: bool = False,
max_degree: int = 4,
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
allow_fused_output: bool = False
):
"""
:param fiber_in: Fiber describing the input features
:param fiber_out: Fiber describing the output features
:param fiber_edge: Fiber describing the edge features (node distances excluded)
:param pool: If True, compute final node features by averaging incoming edge features
:param use_layer_norm: Apply layer normalization between MLP layers
:param self_interaction: Apply self-interaction of nodes
:param max_degree: Maximum degree used in the bases computation
:param fuse_level: Maximum fuse level to use in TFN convolutions
:param allow_fused_output: Allow the module to output a fused representation of features
"""
super().__init__()
self.pool = pool
self.fiber_in = fiber_in
self.fiber_out = fiber_out
self.self_interaction = self_interaction
self.max_degree = max_degree
self.allow_fused_output = allow_fused_output
# channels_in: account for the concatenation of edge features
channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
channels_out_set = set([f.channels for f in self.fiber_out])
unique_channels_in = (len(channels_in_set) == 1)
unique_channels_out = (len(channels_out_set) == 1)
degrees_up_to_max = list(range(max_degree + 1))
common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm)
if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
# Single fused convolution
self.used_fuse_level = ConvSE3FuseLevel.FULL
sum_freq = sum([
degree_to_dim(min(d_in, d_out))
for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
])
self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0],
fuse_level=self.used_fuse_level, **common_args)
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
unique_channels_in and fiber_in.degrees == degrees_up_to_max:
# Convolutions fused per output degree
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
self.conv_out = nn.ModuleDict()
for d_out, c_out in fiber_out:
sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out,
fuse_level=self.used_fuse_level, **common_args)
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
# Convolutions fused per input degree
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
self.conv_in = nn.ModuleDict()
for d_in, c_in in fiber_in:
sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
channels_in_new = c_in + fiber_edge[d_in] * (d_in > 0)
self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, channels_in_new, list(channels_out_set)[0],
fuse_level=self.used_fuse_level, **common_args)
else:
# Use pairwise TFN convolutions
self.used_fuse_level = ConvSE3FuseLevel.NONE
self.conv = nn.ModuleDict()
for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out):
dict_key = f'{degree_in},{degree_out}'
channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
sum_freq = degree_to_dim(min(degree_in, degree_out))
self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out,
fuse_level=self.used_fuse_level, **common_args)
if self_interaction:
self.to_kernel_self = nn.ParameterDict()
for degree_out, channels_out in fiber_out:
if fiber_in[degree_out]:
self.to_kernel_self[str(degree_out)] = nn.Parameter(
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
def _try_unpad(self, feature, basis):
# Account for padded basis
if basis is not None:
out_dim = basis.shape[-1]
out_dim += out_dim % 2 - 1
return feature[..., :out_dim]
else:
return feature
def forward(
self,
node_feats: Dict[str, Tensor],
edge_feats: Dict[str, Tensor],
graph: DGLGraph,
basis: Dict[str, Tensor]
):
with nvtx_range(f'ConvSE3'):
invariant_edge_feats = edge_feats['0'].squeeze(-1)
src, dst = graph.edges()
out = {}
in_features = []
# Fetch all input features from edge and node features
for degree_in in self.fiber_in.degrees:
src_node_features = node_feats[str(degree_in)][src]
if degree_in > 0 and str(degree_in) in edge_feats:
# Handle edge features of any type by concatenating them to node features
src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
in_features.append(src_node_features)
if self.used_fuse_level == ConvSE3FuseLevel.FULL:
in_features_fused = torch.cat(in_features, dim=-1)
out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])
if not self.allow_fused_output or self.self_interaction or self.pool:
out = unfuse_features(out, self.fiber_out.degrees)
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
in_features_fused = torch.cat(in_features, dim=-1)
for degree_out in self.fiber_out.degrees:
basis_used = basis[f'out{degree_out}_fused']
out[str(degree_out)] = self._try_unpad(
self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats, basis_used),
basis_used)
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
out = 0
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
out = out + self.conv_in[str(degree_in)](feature, invariant_edge_feats, basis[f'in{degree_in}_fused'])
if not self.allow_fused_output or self.self_interaction or self.pool:
out = unfuse_features(out, self.fiber_out.degrees)
else:
# Fallback to pairwise TFN convolutions
for degree_out in self.fiber_out.degrees:
out_feature = 0
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
dict_key = f'{degree_in},{degree_out}'
basis_used = basis.get(dict_key, None)
out_feature = out_feature + self._try_unpad(
self.conv[dict_key](feature, invariant_edge_feats, basis_used),
basis_used)
out[str(degree_out)] = out_feature
for degree_out in self.fiber_out.degrees:
if self.self_interaction and str(degree_out) in self.to_kernel_self:
with nvtx_range(f'self interaction'):
dst_features = node_feats[str(degree_out)][dst]
kernel_self = self.to_kernel_self[str(degree_out)]
out[str(degree_out)] = out[str(degree_out)] + kernel_self @ dst_features
if self.pool:
with nvtx_range(f'pooling'):
if isinstance(out, dict):
out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
else:
out = dgl.ops.copy_e_sum(graph, out)
return out
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from enum import Enum
from itertools import product
from typing import Dict
import dgl
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
from dgl import DGLGraph
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime.utils import degree_to_dim, unfuse_features
class ConvSE3FuseLevel(Enum):
"""
Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met.
If a desired level L is picked and the level L cannot be applied to a level, other fused ops < L are considered.
A higher level means faster training, but also more memory usage.
If you are tight on memory and want to feed large inputs to the network, choose a low value.
If you want to train fast, choose a high value.
Recommended value is FULL with AMP.
Fully fused TFN convolutions requirements:
- all input channels are the same
- all output channels are the same
- input degrees span the range [0, ..., max_degree]
- output degrees span the range [0, ..., max_degree]
Partially fused TFN convolutions requirements:
* For fusing by output degree:
- all input channels are the same
- input degrees span the range [0, ..., max_degree]
* For fusing by input degree:
- all output channels are the same
- output degrees span the range [0, ..., max_degree]
Original TFN pairwise convolutions: no requirements
"""
FULL = 2
PARTIAL = 1
NONE = 0
class RadialProfile(nn.Module):
"""
Radial profile function.
Outputs weights used to weigh basis matrices in order to get convolution kernels.
In TFN notation: $R^{l,k}$
In SE(3)-Transformer notation: $\phi^{l,k}$
Note:
In the original papers, this function only depends on relative node distances ||x||.
Here, we allow this function to also take as input additional invariant edge features.
This does not break equivariance and adds expressive power to the model.
Diagram:
invariant edge features (node distances included) > MLP layer (shared across edges) > radial weights
"""
def __init__(
self,
num_freq: int,
channels_in: int,
channels_out: int,
edge_dim: int = 1,
mid_dim: int = 32,
use_layer_norm: bool = False
):
"""
:param num_freq: Number of frequencies
:param channels_in: Number of input channels
:param channels_out: Number of output channels
:param edge_dim: Number of invariant edge features (input to the radial function)
:param mid_dim: Size of the hidden MLP layers
:param use_layer_norm: Apply layer normalization between MLP layers
"""
super().__init__()
modules = [
nn.Linear(edge_dim, mid_dim),
nn.LayerNorm(mid_dim) if use_layer_norm else None,
nn.ReLU(),
nn.Linear(mid_dim, mid_dim),
nn.LayerNorm(mid_dim) if use_layer_norm else None,
nn.ReLU(),
nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
]
self.net = nn.Sequential(*[m for m in modules if m is not None])
def forward(self, features: Tensor) -> Tensor:
return self.net(features)
class VersatileConvSE3(nn.Module):
"""
Building block for TFN convolutions.
This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions.
"""
def __init__(self,
freq_sum: int,
channels_in: int,
channels_out: int,
edge_dim: int,
use_layer_norm: bool,
fuse_level: ConvSE3FuseLevel):
super().__init__()
self.freq_sum = freq_sum
self.channels_out = channels_out
self.channels_in = channels_in
self.fuse_level = fuse_level
self.radial_func = RadialProfile(num_freq=freq_sum,
channels_in=channels_in,
channels_out=channels_out,
edge_dim=edge_dim,
use_layer_norm=use_layer_norm)
def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
with nvtx_range(f'VersatileConvSE3'):
num_edges = features.shape[0]
in_dim = features.shape[2]
with nvtx_range(f'RadialProfile'):
radial_weights = self.radial_func(invariant_edge_feats) \
.view(-1, self.channels_out, self.channels_in * self.freq_sum)
if basis is not None:
# This block performs the einsum n i l, n o i f, n l f k -> n o k
basis_view = basis.view(num_edges, in_dim, -1)
tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
return radial_weights @ tmp
else:
# k = l = 0 non-fused case
return radial_weights @ features
class ConvSE3(nn.Module):
"""
SE(3)-equivariant graph convolution (Tensor Field Network convolution).
This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance.
Features of different degrees interact together to produce output features.
Note 1:
The option is given to not pool the output. This means that the convolution sum over neighbors will not be
done, and the returned features will be edge features instead of node features.
Note 2:
Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0.
Input edge features are concatenated with input source node features before the kernel is applied.
"""
def __init__(
self,
fiber_in: Fiber,
fiber_out: Fiber,
fiber_edge: Fiber,
pool: bool = True,
use_layer_norm: bool = False,
self_interaction: bool = False,
max_degree: int = 4,
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
allow_fused_output: bool = False,
low_memory: bool = False
):
"""
:param fiber_in: Fiber describing the input features
:param fiber_out: Fiber describing the output features
:param fiber_edge: Fiber describing the edge features (node distances excluded)
:param pool: If True, compute final node features by averaging incoming edge features
:param use_layer_norm: Apply layer normalization between MLP layers
:param self_interaction: Apply self-interaction of nodes
:param max_degree: Maximum degree used in the bases computation
:param fuse_level: Maximum fuse level to use in TFN convolutions
:param allow_fused_output: Allow the module to output a fused representation of features
"""
super().__init__()
self.pool = pool
self.fiber_in = fiber_in
self.fiber_out = fiber_out
self.self_interaction = self_interaction
self.max_degree = max_degree
self.allow_fused_output = allow_fused_output
self.conv_checkpoint = torch.utils.checkpoint.checkpoint if low_memory else lambda m, *x: m(*x)
# channels_in: account for the concatenation of edge features
channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
channels_out_set = set([f.channels for f in self.fiber_out])
unique_channels_in = (len(channels_in_set) == 1)
unique_channels_out = (len(channels_out_set) == 1)
degrees_up_to_max = list(range(max_degree + 1))
common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm)
if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
# Single fused convolution
self.used_fuse_level = ConvSE3FuseLevel.FULL
sum_freq = sum([
degree_to_dim(min(d_in, d_out))
for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
])
self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0],
fuse_level=self.used_fuse_level, **common_args)
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
unique_channels_in and fiber_in.degrees == degrees_up_to_max:
# Convolutions fused per output degree
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
self.conv_out = nn.ModuleDict()
for d_out, c_out in fiber_out:
sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out,
fuse_level=self.used_fuse_level, **common_args)
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
# Convolutions fused per input degree
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
self.conv_in = nn.ModuleDict()
for d_in, c_in in fiber_in:
channels_in_new = c_in + fiber_edge[d_in] * (d_in > 0)
sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, channels_in_new, list(channels_out_set)[0],
fuse_level=self.used_fuse_level, **common_args)
else:
# Use pairwise TFN convolutions
self.used_fuse_level = ConvSE3FuseLevel.NONE
self.conv = nn.ModuleDict()
for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out):
dict_key = f'{degree_in},{degree_out}'
channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
sum_freq = degree_to_dim(min(degree_in, degree_out))
self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out,
fuse_level=self.used_fuse_level, **common_args)
if self_interaction:
self.to_kernel_self = nn.ParameterDict()
for degree_out, channels_out in fiber_out:
if fiber_in[degree_out]:
self.to_kernel_self[str(degree_out)] = nn.Parameter(
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
def _try_unpad(self, feature, basis):
# Account for padded basis
if basis is not None:
out_dim = basis.shape[-1]
out_dim += out_dim % 2 - 1
return feature[..., :out_dim]
else:
return feature
def forward(
self,
node_feats: Dict[str, Tensor],
edge_feats: Dict[str, Tensor],
graph: DGLGraph,
basis: Dict[str, Tensor]
):
with nvtx_range(f'ConvSE3'):
invariant_edge_feats = edge_feats['0'].squeeze(-1)
src, dst = graph.edges()
out = {}
in_features = []
# Fetch all input features from edge and node features
for degree_in in self.fiber_in.degrees:
src_node_features = node_feats[str(degree_in)][src]
if degree_in > 0 and str(degree_in) in edge_feats:
# Handle edge features of any type by concatenating them to node features
src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
in_features.append(src_node_features)
if self.used_fuse_level == ConvSE3FuseLevel.FULL:
in_features_fused = torch.cat(in_features, dim=-1)
out = self.conv_checkpoint(
self.conv, in_features_fused, invariant_edge_feats, basis['fully_fused']
)
if not self.allow_fused_output or self.self_interaction or self.pool:
out = unfuse_features(out, self.fiber_out.degrees)
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
in_features_fused = torch.cat(in_features, dim=-1)
for degree_out in self.fiber_out.degrees:
basis_used = basis[f'out{degree_out}_fused']
out[str(degree_out)] = self._try_unpad(
self.conv_checkpoint(
self.conv_out[str(degree_out)], in_features_fused, invariant_edge_feats, basis_used
), basis_used)
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
out = 0
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
out = out + self.conv_checkpoint(
self.conv_in[str(degree_in)], feature, invariant_edge_feats, basis[f'in{degree_in}_fused']
)
if not self.allow_fused_output or self.self_interaction or self.pool:
out = unfuse_features(out, self.fiber_out.degrees)
else:
# Fallback to pairwise TFN convolutions
for degree_out in self.fiber_out.degrees:
out_feature = 0
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
dict_key = f'{degree_in},{degree_out}'
basis_used = basis.get(dict_key, None)
out_feature = out_feature + self._try_unpad(
self.conv_checkpoint(
self.conv[dict_key], feature, invariant_edge_feats, basis_used
), basis_used)
out[str(degree_out)] = out_feature
for degree_out in self.fiber_out.degrees:
if self.self_interaction and str(degree_out) in self.to_kernel_self:
with nvtx_range(f'self interaction'):
dst_features = node_feats[str(degree_out)][dst]
kernel_self = self.to_kernel_self[str(degree_out)]
out[str(degree_out)] = out[str(degree_out)] + kernel_self @ dst_features
if self.pool:
with nvtx_range(f'pooling'):
if isinstance(out, dict):
out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
else:
out = dgl.ops.copy_e_sum(graph, out)
return out

View File

@ -1,59 +1,59 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from se3_transformer.model.fiber import Fiber
class LinearSE3(nn.Module):
"""
Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
Maps a fiber to a fiber with the same degrees (channels may be different).
No interaction between degrees, but interaction between channels.
type-0 features (C_0 channels) > Linear(bias=False) > type-0 features (C'_0 channels)
type-1 features (C_1 channels) > Linear(bias=False) > type-1 features (C'_1 channels)
:
type-k features (C_k channels) > Linear(bias=False) > type-k features (C'_k channels)
"""
def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
super().__init__()
self.weights = nn.ParameterDict({
str(degree_out): nn.Parameter(
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
for degree_out, channels_out in fiber_out
})
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
return {
degree: self.weights[degree] @ features[degree]
for degree, weight in self.weights.items()
}
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from se3_transformer.model.fiber import Fiber
class LinearSE3(nn.Module):
"""
Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
Maps a fiber to a fiber with the same degrees (channels may be different).
No interaction between degrees, but interaction between channels.
type-0 features (C_0 channels) > Linear(bias=False) > type-0 features (C'_0 channels)
type-1 features (C_1 channels) > Linear(bias=False) > type-1 features (C'_1 channels)
:
type-k features (C_k channels) > Linear(bias=False) > type-k features (C'_k channels)
"""
def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
super().__init__()
self.weights = nn.ParameterDict({
str(degree_out): nn.Parameter(
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
for degree_out, channels_out in fiber_out
})
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
return {
degree: self.weights[degree] @ features[degree]
for degree, weight in self.weights.items()
}

View File

@ -1,83 +1,83 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Dict
import torch
import torch.nn as nn
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.model.fiber import Fiber
class NormSE3(nn.Module):
"""
Norm-based SE(3)-equivariant nonlinearity.
> feature_norm > LayerNorm() > ReLU()
feature_in * > feature_out
> feature_phase
"""
NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
super().__init__()
self.fiber = fiber
self.nonlinearity = nonlinearity
if len(set(fiber.channels)) == 1:
# Fuse all the layer normalizations into a group normalization
self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
else:
# Use multiple layer normalizations
self.layer_norms = nn.ModuleDict({
str(degree): nn.LayerNorm(channels)
for degree, channels in fiber
})
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
with nvtx_range('NormSE3'):
output = {}
if hasattr(self, 'group_norm'):
# Compute per-degree norms of features
norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
for d in self.fiber.degrees]
fused_norms = torch.cat(norms, dim=-2)
# Transform the norms only
new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)
# Scale features to the new norms
for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
output[str(d)] = features[str(d)] / norm * new_norm
else:
for degree, feat in features.items():
norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
output[degree] = new_norm * feat / norm
return output
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Dict
import torch
import torch.nn as nn
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.model.fiber import Fiber
class NormSE3(nn.Module):
"""
Norm-based SE(3)-equivariant nonlinearity.
> feature_norm > LayerNorm() > ReLU()
feature_in * > feature_out
> feature_phase
"""
NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
super().__init__()
self.fiber = fiber
self.nonlinearity = nonlinearity
if len(set(fiber.channels)) == 1:
# Fuse all the layer normalizations into a group normalization
self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
else:
# Use multiple layer normalizations
self.layer_norms = nn.ModuleDict({
str(degree): nn.LayerNorm(channels)
for degree, channels in fiber
})
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
with nvtx_range('NormSE3'):
output = {}
if hasattr(self, 'group_norm'):
# Compute per-degree norms of features
norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
for d in self.fiber.degrees]
fused_norms = torch.cat(norms, dim=-2)
# Transform the norms only
new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)
# Scale features to the new norms
for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
output[str(d)] = features[str(d)] / norm * new_norm
else:
for degree, feat in features.items():
norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
output[degree] = new_norm * feat / norm
return output

View File

@ -1,53 +1,53 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Dict, Literal
import torch.nn as nn
from dgl import DGLGraph
from dgl.nn.pytorch import AvgPooling, MaxPooling
from torch import Tensor
class GPooling(nn.Module):
"""
Graph max/average pooling on a given feature type.
The average can be taken for any feature type, and equivariance will be maintained.
The maximum can only be taken for invariant features (type 0).
If you want max-pooling for type > 0 features, look into Vector Neurons.
"""
def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
"""
:param feat_type: Feature type to pool
:param pool: Type of pooling: max or avg
"""
super().__init__()
assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
self.feat_type = feat_type
self.pool = MaxPooling() if pool == 'max' else AvgPooling()
def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
pooled = self.pool(graph, features[str(self.feat_type)])
return pooled.squeeze(dim=-1)
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Dict, Literal
import torch.nn as nn
from dgl import DGLGraph
from dgl.nn.pytorch import AvgPooling, MaxPooling
from torch import Tensor
class GPooling(nn.Module):
"""
Graph max/average pooling on a given feature type.
The average can be taken for any feature type, and equivariance will be maintained.
The maximum can only be taken for invariant features (type 0).
If you want max-pooling for type > 0 features, look into Vector Neurons.
"""
def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
"""
:param feat_type: Feature type to pool
:param pool: Type of pooling: max or avg
"""
super().__init__()
assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
self.feat_type = feat_type
self.pool = MaxPooling() if pool == 'max' else AvgPooling()
def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
pooled = self.pool(graph, features[str(self.feat_type)])
return pooled.squeeze(dim=-1)

View File

@ -1,222 +1,223 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import logging
from typing import Optional, Literal, Dict
import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor
from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.layers.attention import AttentionBlockSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling
from se3_transformer.runtime.utils import str2bool
from se3_transformer.model.fiber import Fiber
class Sequential(nn.Sequential):
""" Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """
def forward(self, input, *args, **kwargs):
for module in self:
input = module(input, *args, **kwargs)
return input
def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
""" Add relative positions to existing edge features """
edge_features = edge_features.copy() if edge_features else {}
r = relative_pos.norm(dim=-1, keepdim=True)
if '0' in edge_features:
edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
else:
edge_features['0'] = r[..., None]
return edge_features
class SE3Transformer(nn.Module):
def __init__(self,
num_layers: int,
fiber_in: Fiber,
fiber_hidden: Fiber,
fiber_out: Fiber,
num_heads: int,
channels_div: int,
fiber_edge: Fiber = Fiber({}),
return_type: Optional[int] = None,
pooling: Optional[Literal['avg', 'max']] = None,
norm: bool = True,
use_layer_norm: bool = True,
tensor_cores: bool = False,
low_memory: bool = False,
**kwargs):
"""
:param num_layers: Number of attention layers
:param fiber_in: Input fiber description
:param fiber_hidden: Hidden fiber description
:param fiber_out: Output fiber description
:param fiber_edge: Input edge fiber description
:param num_heads: Number of attention heads
:param channels_div: Channels division before feeding to attention layer
:param return_type: Return only features of this type
:param pooling: 'avg' or 'max' graph pooling before MLP layers
:param norm: Apply a normalization layer after each attention block
:param use_layer_norm: Apply layer normalization between MLP layers
:param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
:param low_memory: If True, will use slower ops that use less memory
"""
super().__init__()
self.num_layers = num_layers
self.fiber_edge = fiber_edge
self.num_heads = num_heads
self.channels_div = channels_div
self.return_type = return_type
self.pooling = pooling
self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
self.tensor_cores = tensor_cores
self.low_memory = low_memory
if low_memory and not tensor_cores:
logging.warning('Low memory mode will have no effect with no Tensor Cores')
# Fully fused convolutions when using Tensor Cores (and not low memory mode)
fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL
graph_modules = []
for i in range(num_layers):
graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
fiber_out=fiber_hidden,
fiber_edge=fiber_edge,
num_heads=num_heads,
channels_div=channels_div,
use_layer_norm=use_layer_norm,
max_degree=self.max_degree,
fuse_level=fuse_level))
if norm:
graph_modules.append(NormSE3(fiber_hidden))
fiber_in = fiber_hidden
graph_modules.append(ConvSE3(fiber_in=fiber_in,
fiber_out=fiber_out,
fiber_edge=fiber_edge,
self_interaction=True,
use_layer_norm=use_layer_norm,
max_degree=self.max_degree))
self.graph_modules = Sequential(*graph_modules)
if pooling is not None:
assert return_type is not None, 'return_type must be specified when pooling'
self.pooling_module = GPooling(pool=pooling, feat_type=return_type)
def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
edge_feats: Optional[Dict[str, Tensor]] = None,
basis: Optional[Dict[str, Tensor]] = None):
# Compute bases in case they weren't precomputed as part of the data loading
basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
use_pad_trick=self.tensor_cores and not self.low_memory,
amp=torch.is_autocast_enabled())
# Add fused bases (per output degree, per input degree, and fully fused) to the dict
basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
fully_fused=self.tensor_cores and not self.low_memory)
edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)
node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)
if self.pooling is not None:
return self.pooling_module(node_feats, graph=graph)
if self.return_type is not None:
return node_feats[str(self.return_type)]
return node_feats
@staticmethod
def add_argparse_args(parser):
parser.add_argument('--num_layers', type=int, default=7,
help='Number of stacked Transformer layers')
parser.add_argument('--num_heads', type=int, default=8,
help='Number of heads in self-attention')
parser.add_argument('--channels_div', type=int, default=2,
help='Channels division before feeding to attention layer')
parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
help='Type of graph pooling')
parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
help='Apply a normalization layer after each attention block')
parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
help='Apply layer normalization between MLP layers')
parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
help='If true, will use fused ops that are slower but that use less memory '
'(expect 25 percent less memory). '
'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')
return parser
class SE3TransformerPooled(nn.Module):
def __init__(self,
fiber_in: Fiber,
fiber_out: Fiber,
fiber_edge: Fiber,
num_degrees: int,
num_channels: int,
output_dim: int,
**kwargs):
super().__init__()
kwargs['pooling'] = kwargs['pooling'] or 'max'
self.transformer = SE3Transformer(
fiber_in=fiber_in,
fiber_hidden=Fiber.create(num_degrees, num_channels),
fiber_out=fiber_out,
fiber_edge=fiber_edge,
return_type=0,
**kwargs
)
n_out_features = fiber_out.num_features
self.mlp = nn.Sequential(
nn.Linear(n_out_features, n_out_features),
nn.ReLU(),
nn.Linear(n_out_features, output_dim)
)
def forward(self, graph, node_feats, edge_feats, basis=None):
feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
y = self.mlp(feats).squeeze(-1)
return y
@staticmethod
def add_argparse_args(parent_parser):
parser = parent_parser.add_argument_group("Model architecture")
SE3Transformer.add_argparse_args(parser)
parser.add_argument('--num_degrees',
help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
type=int, default=4)
parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
return parent_parser
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import logging
from typing import Optional, Literal, Dict
import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor
from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.layers.attention import AttentionBlockSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling
from se3_transformer.runtime.utils import str2bool
from se3_transformer.model.fiber import Fiber
class Sequential(nn.Sequential):
""" Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """
def forward(self, input, *args, **kwargs):
for module in self:
input = module(input, *args, **kwargs)
return input
def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
""" Add relative positions to existing edge features """
edge_features = edge_features.copy() if edge_features else {}
r = relative_pos.norm(dim=-1, keepdim=True)
if '0' in edge_features:
edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
else:
edge_features['0'] = r[..., None]
return edge_features
class SE3Transformer(nn.Module):
def __init__(self,
num_layers: int,
fiber_in: Fiber,
fiber_hidden: Fiber,
fiber_out: Fiber,
num_heads: int,
channels_div: int,
fiber_edge: Fiber = Fiber({}),
return_type: Optional[int] = None,
pooling: Optional[Literal['avg', 'max']] = None,
norm: bool = True,
use_layer_norm: bool = True,
tensor_cores: bool = False,
low_memory: bool = False,
**kwargs):
"""
:param num_layers: Number of attention layers
:param fiber_in: Input fiber description
:param fiber_hidden: Hidden fiber description
:param fiber_out: Output fiber description
:param fiber_edge: Input edge fiber description
:param num_heads: Number of attention heads
:param channels_div: Channels division before feeding to attention layer
:param return_type: Return only features of this type
:param pooling: 'avg' or 'max' graph pooling before MLP layers
:param norm: Apply a normalization layer after each attention block
:param use_layer_norm: Apply layer normalization between MLP layers
:param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
:param low_memory: If True, will use slower ops that use less memory
"""
super().__init__()
self.num_layers = num_layers
self.fiber_edge = fiber_edge
self.num_heads = num_heads
self.channels_div = channels_div
self.return_type = return_type
self.pooling = pooling
self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
self.tensor_cores = tensor_cores
self.low_memory = low_memory
if low_memory:
self.fuse_level = ConvSE3FuseLevel.NONE
else:
# Fully fused convolutions when using Tensor Cores (and not low memory mode)
self.fuse_level = ConvSE3FuseLevel.FULL if tensor_cores else ConvSE3FuseLevel.PARTIAL
graph_modules = []
for i in range(num_layers):
graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
fiber_out=fiber_hidden,
fiber_edge=fiber_edge,
num_heads=num_heads,
channels_div=channels_div,
use_layer_norm=use_layer_norm,
max_degree=self.max_degree,
fuse_level=self.fuse_level,
low_memory=low_memory))
if norm:
graph_modules.append(NormSE3(fiber_hidden))
fiber_in = fiber_hidden
graph_modules.append(ConvSE3(fiber_in=fiber_in,
fiber_out=fiber_out,
fiber_edge=fiber_edge,
self_interaction=True,
use_layer_norm=use_layer_norm,
max_degree=self.max_degree))
self.graph_modules = Sequential(*graph_modules)
if pooling is not None:
assert return_type is not None, 'return_type must be specified when pooling'
self.pooling_module = GPooling(pool=pooling, feat_type=return_type)
def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
edge_feats: Optional[Dict[str, Tensor]] = None,
basis: Optional[Dict[str, Tensor]] = None):
# Compute bases in case they weren't precomputed as part of the data loading
basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
use_pad_trick=self.tensor_cores and not self.low_memory,
amp=torch.is_autocast_enabled())
# Add fused bases (per output degree, per input degree, and fully fused) to the dict
basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
fully_fused=self.fuse_level == ConvSE3FuseLevel.FULL)
edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)
node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)
if self.pooling is not None:
return self.pooling_module(node_feats, graph=graph)
if self.return_type is not None:
return node_feats[str(self.return_type)]
return node_feats
@staticmethod
def add_argparse_args(parser):
parser.add_argument('--num_layers', type=int, default=7,
help='Number of stacked Transformer layers')
parser.add_argument('--num_heads', type=int, default=8,
help='Number of heads in self-attention')
parser.add_argument('--channels_div', type=int, default=2,
help='Channels division before feeding to attention layer')
parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
help='Type of graph pooling')
parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
help='Apply a normalization layer after each attention block')
parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
help='Apply layer normalization between MLP layers')
parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
help='If true, will use fused ops that are slower but that use less memory '
'(expect 25 percent less memory). '
'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')
return parser
class SE3TransformerPooled(nn.Module):
def __init__(self,
fiber_in: Fiber,
fiber_out: Fiber,
fiber_edge: Fiber,
num_degrees: int,
num_channels: int,
output_dim: int,
**kwargs):
super().__init__()
kwargs['pooling'] = kwargs['pooling'] or 'max'
self.transformer = SE3Transformer(
fiber_in=fiber_in,
fiber_hidden=Fiber.create(num_degrees, num_channels),
fiber_out=fiber_out,
fiber_edge=fiber_edge,
return_type=0,
**kwargs
)
n_out_features = fiber_out.num_features
self.mlp = nn.Sequential(
nn.Linear(n_out_features, n_out_features),
nn.ReLU(),
nn.Linear(n_out_features, output_dim)
)
def forward(self, graph, node_feats, edge_feats, basis=None):
feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
y = self.mlp(feats).squeeze(-1)
return y
@staticmethod
def add_argparse_args(parent_parser):
parser = parent_parser.add_argument_group("Model architecture")
SE3Transformer.add_argparse_args(parser)
parser.add_argument('--num_degrees',
help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
type=int, default=4)
parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
return parent_parser

View File

@ -1,70 +1,72 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import argparse
import pathlib
from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.runtime.utils import str2bool
PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')
paths = PARSER.add_argument_group('Paths')
paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'),
help='Directory where the data is located or should be downloaded')
paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'),
help='Directory where the results logs should be saved')
paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json',
help='Name for the resulting DLLogger JSON file')
paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None,
help='File where the checkpoint should be saved')
paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None,
help='File of the checkpoint to be loaded')
optimizer = PARSER.add_argument_group('Optimizer')
optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam')
optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002)
optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None)
optimizer.add_argument('--momentum', type=float, default=0.9)
optimizer.add_argument('--weight_decay', type=float, default=0.1)
PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size')
PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally')
PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers')
PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision')
PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=20,
help='Do an evaluation round every N epochs')
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
help='Minimize stdout output')
PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
help='Benchmark mode')
QM9DataModule.add_argparse_args(PARSER)
SE3TransformerPooled.add_argparse_args(PARSER)
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import argparse
import pathlib
from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.runtime.utils import str2bool
PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')
paths = PARSER.add_argument_group('Paths')
paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'),
help='Directory where the data is located or should be downloaded')
paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'),
help='Directory where the results logs should be saved')
paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json',
help='Name for the resulting DLLogger JSON file')
paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None,
help='File where the checkpoint should be saved')
paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None,
help='File of the checkpoint to be loaded')
optimizer = PARSER.add_argument_group('Optimizer')
optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam')
optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002)
optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None)
optimizer.add_argument('--momentum', type=float, default=0.9)
optimizer.add_argument('--weight_decay', type=float, default=0.1)
PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size')
PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally')
PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers')
PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision')
PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=20,
help='Do an evaluation round every N epochs')
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
help='Minimize stdout output')
PARSER.add_argument('--wandb', type=str2bool, nargs='?', const=True, default=False,
help='Enable W&B logging')
PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
help='Benchmark mode')
QM9DataModule.add_argparse_args(PARSER)
SE3TransformerPooled.add_argparse_args(PARSER)

View File

@ -32,7 +32,7 @@ from tqdm import tqdm
from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import BaseCallback
from se3_transformer.runtime.loggers import DLLogger
from se3_transformer.runtime.loggers import DLLogger, WandbLogger, LoggerCollection
from se3_transformer.runtime.utils import to_cuda, get_local_rank
@ -87,7 +87,10 @@ if __name__ == '__main__':
major_cc, minor_cc = torch.cuda.get_device_capability()
logger = DLLogger(args.log_dir, filename=args.dllogger_name)
loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)]
if args.wandb:
loggers.append(WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer'))
logger = LoggerCollection(loggers)
datamodule = QM9DataModule(**vars(args))
model = SE3TransformerPooled(
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
@ -108,6 +111,7 @@ if __name__ == '__main__':
nproc_per_node = torch.cuda.device_count()
affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node)
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
model._set_static_graph()
test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
evaluate(model,

View File

@ -1,240 +1,241 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import logging
import pathlib
from typing import List
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from apex.optimizers import FusedAdam, FusedLAMB
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
PerformanceCallback
from se3_transformer.runtime.inference import evaluate
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
using_tensor_cores, increase_l2_fetch_granularity
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
""" Saves model, optimizer and epoch states to path (only once per node) """
if get_local_rank() == 0:
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
checkpoint = {
'state_dict': state_dict,
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch
}
for callback in callbacks:
callback.on_checkpoint_save(checkpoint)
torch.save(checkpoint, str(path))
logging.info(f'Saved checkpoint to {str(path)}')
def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
""" Loads model, optimizer and epoch states from path """
checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
if isinstance(model, DistributedDataParallel):
model.module.load_state_dict(checkpoint['state_dict'])
else:
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
for callback in callbacks:
callback.on_checkpoint_load(checkpoint)
logging.info(f'Loaded checkpoint from {str(path)}')
return checkpoint['epoch']
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
losses = []
for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
*inputs, target = to_cuda(batch)
for callback in callbacks:
callback.on_batch_start()
with torch.cuda.amp.autocast(enabled=args.amp):
pred = model(*inputs)
loss = loss_fn(pred, target) / args.accumulate_grad_batches
grad_scaler.scale(loss).backward()
# gradient accumulation
if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
if args.gradient_clip:
grad_scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
grad_scaler.step(optimizer)
grad_scaler.update()
model.zero_grad(set_to_none=True)
losses.append(loss.item())
return np.mean(losses)
def train(model: nn.Module,
loss_fn: _Loss,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
callbacks: List[BaseCallback],
logger: Logger,
args):
device = torch.cuda.current_device()
model.to(device=device)
local_rank = get_local_rank()
world_size = dist.get_world_size() if dist.is_initialized() else 1
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
model.train()
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
if args.optimizer == 'adam':
optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
weight_decay=args.weight_decay)
elif args.optimizer == 'lamb':
optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
weight_decay=args.weight_decay)
else:
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
weight_decay=args.weight_decay)
epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0
for callback in callbacks:
callback.on_fit_start(optimizer, args)
for epoch_idx in range(epoch_start, args.epochs):
if isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch_idx)
loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
if dist.is_initialized():
loss = torch.tensor(loss, dtype=torch.float, device=device)
torch.distributed.all_reduce(loss)
loss = (loss / world_size).item()
logging.info(f'Train loss: {loss}')
logger.log_metrics({'train loss': loss}, epoch_idx)
for callback in callbacks:
callback.on_epoch_end()
if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
and (epoch_idx + 1) % args.ckpt_interval == 0:
save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)
if not args.benchmark and ((args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0) or epoch_idx + 1 == args.epochs):
evaluate(model, val_dataloader, callbacks, args)
model.train()
for callback in callbacks:
callback.on_validation_end(epoch_idx)
if args.save_ckpt_path is not None and not args.benchmark:
save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)
for callback in callbacks:
callback.on_fit_end()
def print_parameters_count(model):
num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
logging.info(f'Number of trainable parameters: {num_params_trainable}')
if __name__ == '__main__':
is_distributed = init_distributed()
local_rank = get_local_rank()
args = PARSER.parse_args()
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
logging.info('====== SE(3)-Transformer ======')
logging.info('| Training procedure |')
logging.info('===============================')
if args.seed is not None:
logging.info(f'Using seed {args.seed}')
seed_everything(args.seed)
logger = LoggerCollection([
DLLogger(save_dir=args.log_dir, filename=args.dllogger_name),
WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
])
datamodule = QM9DataModule(**vars(args))
model = SE3TransformerPooled(
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
output_dim=1,
tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively
**vars(args)
)
loss_fn = nn.L1Loss()
if args.benchmark:
logging.info('Running benchmark mode')
world_size = dist.get_world_size() if dist.is_initialized() else 1
callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
else:
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
QM9LRSchedulerCallback(logger, epochs=args.epochs)]
if is_distributed:
gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())
print_parameters_count(model)
logger.log_hyperparams(vars(args))
increase_l2_fetch_granularity()
train(model,
loss_fn,
datamodule.train_dataloader(),
datamodule.val_dataloader(),
callbacks,
logger,
args)
logging.info('Training finished successfully')
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import logging
import pathlib
from typing import List
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from apex.optimizers import FusedAdam, FusedLAMB
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
PerformanceCallback
from se3_transformer.runtime.inference import evaluate
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
using_tensor_cores, increase_l2_fetch_granularity
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
""" Saves model, optimizer and epoch states to path (only once per node) """
if get_local_rank() == 0:
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
checkpoint = {
'state_dict': state_dict,
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch
}
for callback in callbacks:
callback.on_checkpoint_save(checkpoint)
torch.save(checkpoint, str(path))
logging.info(f'Saved checkpoint to {str(path)}')
def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
""" Loads model, optimizer and epoch states from path """
checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
if isinstance(model, DistributedDataParallel):
model.module.load_state_dict(checkpoint['state_dict'])
else:
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
for callback in callbacks:
callback.on_checkpoint_load(checkpoint)
logging.info(f'Loaded checkpoint from {str(path)}')
return checkpoint['epoch']
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
losses = []
for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
*inputs, target = to_cuda(batch)
for callback in callbacks:
callback.on_batch_start()
with torch.cuda.amp.autocast(enabled=args.amp):
pred = model(*inputs)
loss = loss_fn(pred, target) / args.accumulate_grad_batches
grad_scaler.scale(loss).backward()
# gradient accumulation
if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
if args.gradient_clip:
grad_scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
grad_scaler.step(optimizer)
grad_scaler.update()
model.zero_grad(set_to_none=True)
losses.append(loss.item())
return np.mean(losses)
def train(model: nn.Module,
loss_fn: _Loss,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
callbacks: List[BaseCallback],
logger: Logger,
args):
device = torch.cuda.current_device()
model.to(device=device)
local_rank = get_local_rank()
world_size = dist.get_world_size() if dist.is_initialized() else 1
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
model._set_static_graph()
model.train()
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
if args.optimizer == 'adam':
optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
weight_decay=args.weight_decay)
elif args.optimizer == 'lamb':
optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
weight_decay=args.weight_decay)
else:
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
weight_decay=args.weight_decay)
epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0
for callback in callbacks:
callback.on_fit_start(optimizer, args)
for epoch_idx in range(epoch_start, args.epochs):
if isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch_idx)
loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks,
args)
if dist.is_initialized():
loss = torch.tensor(loss, dtype=torch.float, device=device)
torch.distributed.all_reduce(loss)
loss = (loss / world_size).item()
logging.info(f'Train loss: {loss}')
logger.log_metrics({'train loss': loss}, epoch_idx)
for callback in callbacks:
callback.on_epoch_end()
if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
and (epoch_idx + 1) % args.ckpt_interval == 0:
save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)
if not args.benchmark and (
(args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0) or epoch_idx + 1 == args.epochs):
evaluate(model, val_dataloader, callbacks, args)
model.train()
for callback in callbacks:
callback.on_validation_end(epoch_idx)
if args.save_ckpt_path is not None and not args.benchmark:
save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)
for callback in callbacks:
callback.on_fit_end()
def print_parameters_count(model):
num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
logging.info(f'Number of trainable parameters: {num_params_trainable}')
if __name__ == '__main__':
is_distributed = init_distributed()
local_rank = get_local_rank()
args = PARSER.parse_args()
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
logging.info('====== SE(3)-Transformer ======')
logging.info('| Training procedure |')
logging.info('===============================')
if args.seed is not None:
logging.info(f'Using seed {args.seed}')
seed_everything(args.seed)
loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)]
if args.wandb:
loggers.append(WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer'))
logger = LoggerCollection(loggers)
datamodule = QM9DataModule(**vars(args))
model = SE3TransformerPooled(
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
output_dim=1,
tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively
**vars(args)
)
loss_fn = nn.L1Loss()
if args.benchmark:
logging.info('Running benchmark mode')
world_size = dist.get_world_size() if dist.is_initialized() else 1
callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
else:
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
QM9LRSchedulerCallback(logger, epochs=args.epochs)]
if is_distributed:
gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())
print_parameters_count(model)
logger.log_hyperparams(vars(args))
increase_l2_fetch_granularity()
train(model,
loss_fn,
datamodule.train_dataloader(),
datamodule.val_dataloader(),
callbacks,
logger,
args)
logging.info('Training finished successfully')

View File

@ -2,9 +2,9 @@ from setuptools import setup, find_packages
setup(
name='se3-transformer',
packages=find_packages(),
packages=find_packages(exclude=['tests']),
include_package_data=True,
version='1.0.0',
version='1.1.0',
description='PyTorch + DGL implementation of SE(3)-Transformers',
author='Alexandre Milesi',
author_email='alexandrem@nvidia.com',

View File

@ -1,102 +1,106 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import torch
from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber
from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot
# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) )
TOL = 1e-3
CHANNELS, NODES = 32, 512
def _get_outputs(model, R):
feats0 = torch.randn(NODES, CHANNELS, 1)
feats1 = torch.randn(NODES, CHANNELS, 3)
coords = torch.randn(NODES, 3)
graph = get_random_graph(NODES)
if torch.cuda.is_available():
feats0 = feats0.cuda()
feats1 = feats1.cuda()
R = R.cuda()
coords = coords.cuda()
graph = graph.to('cuda')
model.cuda()
graph1 = assign_relative_pos(graph, coords)
out1 = model(graph1, {'0': feats0, '1': feats1}, {})
graph2 = assign_relative_pos(graph, coords @ R)
out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {})
return out1, out2
def _get_model(**kwargs):
return SE3Transformer(
num_layers=4,
fiber_in=Fiber.create(2, CHANNELS),
fiber_hidden=Fiber.create(3, CHANNELS),
fiber_out=Fiber.create(2, CHANNELS),
fiber_edge=Fiber({}),
num_heads=8,
channels_div=2,
**kwargs
)
def test_equivariance():
model = _get_model()
R = rot(*torch.rand(3))
if torch.cuda.is_available():
R = R.cuda()
out1, out2 = _get_outputs(model, R)
assert torch.allclose(out2['0'], out1['0'], atol=TOL), \
f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}'
assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \
f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}'
def test_equivariance_pooled():
model = _get_model(pooling='avg', return_type=1)
R = rot(*torch.rand(3))
if torch.cuda.is_available():
R = R.cuda()
out1, out2 = _get_outputs(model, R)
assert torch.allclose(out2, (out1 @ R), atol=TOL), \
f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}'
def test_invariance_pooled():
model = _get_model(pooling='avg', return_type=0)
R = rot(*torch.rand(3))
if torch.cuda.is_available():
R = R.cuda()
out1, out2 = _get_outputs(model, R)
assert torch.allclose(out2, out1, atol=TOL), \
f'type-0 features should be invariant {get_max_diff(out1, out2)}'
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import torch
from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber
if __package__ is None or __package__ == '':
from utils import get_random_graph, assign_relative_pos, get_max_diff, rot
else:
from .utils import get_random_graph, assign_relative_pos, get_max_diff, rot
# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) )
TOL = 1e-3
CHANNELS, NODES = 32, 512
def _get_outputs(model, R):
feats0 = torch.randn(NODES, CHANNELS, 1)
feats1 = torch.randn(NODES, CHANNELS, 3)
coords = torch.randn(NODES, 3)
graph = get_random_graph(NODES)
if torch.cuda.is_available():
feats0 = feats0.cuda()
feats1 = feats1.cuda()
R = R.cuda()
coords = coords.cuda()
graph = graph.to('cuda')
model.cuda()
graph1 = assign_relative_pos(graph, coords)
out1 = model(graph1, {'0': feats0, '1': feats1}, {})
graph2 = assign_relative_pos(graph, coords @ R)
out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {})
return out1, out2
def _get_model(**kwargs):
return SE3Transformer(
num_layers=4,
fiber_in=Fiber.create(2, CHANNELS),
fiber_hidden=Fiber.create(3, CHANNELS),
fiber_out=Fiber.create(2, CHANNELS),
fiber_edge=Fiber({}),
num_heads=8,
channels_div=2,
**kwargs
)
def test_equivariance():
model = _get_model()
R = rot(*torch.rand(3))
if torch.cuda.is_available():
R = R.cuda()
out1, out2 = _get_outputs(model, R)
assert torch.allclose(out2['0'], out1['0'], atol=TOL), \
f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}'
assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \
f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}'
def test_equivariance_pooled():
model = _get_model(pooling='avg', return_type=1)
R = rot(*torch.rand(3))
if torch.cuda.is_available():
R = R.cuda()
out1, out2 = _get_outputs(model, R)
assert torch.allclose(out2, (out1 @ R), atol=TOL), \
f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}'
def test_invariance_pooled():
model = _get_model(pooling='avg', return_type=0)
R = rot(*torch.rand(3))
if torch.cuda.is_available():
R = R.cuda()
out1, out2 = _get_outputs(model, R)
assert torch.allclose(out2, out1, atol=TOL), \
f'type-0 features should be invariant {get_max_diff(out1, out2)}'

View File

@ -1,60 +1,60 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import dgl
import torch
def get_random_graph(N, num_edges_factor=18):
graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor))
return graph
def assign_relative_pos(graph, coords):
src, dst = graph.edges()
graph.edata['rel_pos'] = coords[src] - coords[dst]
return graph
def get_max_diff(a, b):
return (a - b).abs().max().item()
def rot_z(gamma):
return torch.tensor([
[torch.cos(gamma), -torch.sin(gamma), 0],
[torch.sin(gamma), torch.cos(gamma), 0],
[0, 0, 1]
], dtype=gamma.dtype)
def rot_y(beta):
return torch.tensor([
[torch.cos(beta), 0, torch.sin(beta)],
[0, 1, 0],
[-torch.sin(beta), 0, torch.cos(beta)]
], dtype=beta.dtype)
def rot(alpha, beta, gamma):
return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import dgl
import torch
def get_random_graph(N, num_edges_factor=18):
graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor))
return graph
def assign_relative_pos(graph, coords):
src, dst = graph.edges()
graph.edata['rel_pos'] = coords[src] - coords[dst]
return graph
def get_max_diff(a, b):
return (a - b).abs().max().item()
def rot_z(gamma):
return torch.tensor([
[torch.cos(gamma), -torch.sin(gamma), 0],
[torch.sin(gamma), torch.cos(gamma), 0],
[0, 0, 1]
], dtype=gamma.dtype)
def rot_y(beta):
return torch.tensor([
[torch.cos(beta), 0, torch.sin(beta)],
[0, 1, 0],
[-torch.sin(beta), 0, torch.cos(beta)]
], dtype=beta.dtype)
def rot(alpha, beta, gamma):
return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)

View File

@ -83,7 +83,7 @@ These examples, along with our NVIDIA deep learning software stack, are provided
## Graph Neural Networks
| Models | Framework | A100 | AMP | Multi-GPU | Multi-Node | TRT | ONNX | Triton | DLC | NB |
| ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |------------- |------------- |------------- |------------- |------------- |
| [SE(3)-Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/DrugDiscovery/SE3Transformer) | PyTorch | Yes | Yes | Yes | - | - | - | - | - | - |
| [SE(3)-Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/DGLPyTorch/DrugDiscovery/SE3Transformer) | PyTorch | Yes | Yes | Yes | - | - | - | - | - | - |
## NVIDIA support