[SE3Transformer/DGLPyT] Better low memory mode

Alexandre Milesi 2021-11-01 17:49:17 +01:00
parent dcd3bbac09
commit 22d6621dcd
19 changed files with 1610 additions and 1574 deletions


@ -42,7 +42,6 @@ RUN make -j8
FROM ${FROM_IMAGE_NAME} FROM ${FROM_IMAGE_NAME}
RUN rm -rf /workspace/*
WORKDIR /workspace/se3-transformer WORKDIR /workspace/se3-transformer
# copy built DGL and install it # copy built DGL and install it
@ -55,3 +54,5 @@ ADD . .
ENV DGLBACKEND=pytorch ENV DGLBACKEND=pytorch
ENV OMP_NUM_THREADS=1 ENV OMP_NUM_THREADS=1


@ -126,7 +126,13 @@ The following performance optimizations were implemented in this model:
- The layout (order of dimensions) of the bases tensors is optimized to avoid copies to contiguous memory in the downstream TFN layers - The layout (order of dimensions) of the bases tensors is optimized to avoid copies to contiguous memory in the downstream TFN layers
- When Tensor Cores are available, and the output feature dimension of computed bases is odd, then it is padded with zeros to make more effective use of Tensor Cores (AMP and TF32 precisions) - When Tensor Cores are available, and the output feature dimension of computed bases is odd, then it is padded with zeros to make more effective use of Tensor Cores (AMP and TF32 precisions)
- Multiple levels of fusion for TFN convolutions (and radial profiles) are provided and automatically used when conditions are met - Multiple levels of fusion for TFN convolutions (and radial profiles) are provided and automatically used when conditions are met
- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`) - A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`). Overview of memory savings over the official implementation (batch size 100), depending on the precision and the low memory mode:
| | FP32 | AMP
|---|-----------------------|--------------------------
|`--low_memory false` (default) | 4.7x | 7.1x
|`--low_memory true` | 29.4x | 43.6x
**Self-attention optimizations** **Self-attention optimizations**
@ -358,7 +364,7 @@ The complete list of the available parameters for the `training.py` script conta
- `--pooling`: Type of graph pooling (default: `max`) - `--pooling`: Type of graph pooling (default: `max`)
- `--norm`: Apply a normalization layer after each attention block (default: `false`) - `--norm`: Apply a normalization layer after each attention block (default: `false`)
- `--use_layer_norm`: Apply layer normalization between MLP layers (default: `false`) - `--use_layer_norm`: Apply layer normalization between MLP layers (default: `false`)
- `--low_memory`: If true, will use fused ops that are slower but use less memory (expect 25 percent less memory). Only has an effect if AMP is enabled on NVIDIA Volta GPUs or if running on Ampere GPUs (default: `false`) - `--low_memory`: If true, will use ops that are slower but use less memory (default: `false`)
- `--num_degrees`: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: `4`) - `--num_degrees`: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: `4`)
- `--num_channels`: Number of channels for the hidden features (default: `32`) - `--num_channels`: Number of channels for the hidden features (default: `32`)
@ -407,7 +413,8 @@ The training script is `se3_transformer/runtime/training.py`, to be run as a mod
By default, the resulting logs are stored in `/results/`. This can be changed with `--log_dir`. By default, the resulting logs are stored in `/results/`. This can be changed with `--log_dir`.
You can connect your existing Weights & Biases account by setting the `WANDB_API_KEY` environment variable. You can connect your existing Weights & Biases account by setting the `WANDB_API_KEY` environment variable and enabling the `--wandb` flag.
If no API key is set, `--wandb` will log the run anonymously to Weights & Biases.
**Checkpoints** **Checkpoints**
@ -573,6 +580,11 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
### Changelog ### Changelog
November 2021:
- Improved low memory mode to give further 6x memory savings
- Disabled W&B logging by default
- Fixed persistent workers when using one data loading process
October 2021: October 2021:
- Updated README performance tables - Updated README performance tables
- Fixed shape mismatch when using partially fused TFNs per output degree - Fixed shape mismatch when using partially fused TFNs per output degree


@ -46,7 +46,8 @@ class DataModule(ABC):
if dist.is_initialized(): if dist.is_initialized():
dist.barrier(device_ids=[get_local_rank()]) dist.barrier(device_ids=[get_local_rank()])
self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs} self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': dataloader_kwargs.get('num_workers', 0) > 0,
**dataloader_kwargs}
self.ds_train, self.ds_val, self.ds_test = None, None, None self.ds_train, self.ds_val, self.ds_test = None, None, None
def prepare_data(self): def prepare_data(self):
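
The `persistent_workers` change in this hunk matters because PyTorch's `DataLoader` rejects `persistent_workers=True` when `num_workers` is 0. Below is a minimal, torch-free sketch of the same kwargs logic. `build_dataloader_kwargs` is a hypothetical helper name used only for illustration and is not part of the repository.

```python
def build_dataloader_kwargs(**dataloader_kwargs):
    """Hypothetical helper mirroring the dict built in DataModule.__init__."""
    return {
        'pin_memory': True,
        # Only keep workers alive between epochs if worker processes exist.
        'persistent_workers': dataloader_kwargs.get('num_workers', 0) > 0,
        # Caller-supplied values come last, so they override the defaults above.
        **dataloader_kwargs,
    }


print(build_dataloader_kwargs(num_workers=0, batch_size=8))
print(build_dataloader_kwargs(num_workers=4))
```

Because the caller's kwargs are unpacked last, an explicit `persistent_workers` passed by the caller would still take precedence over the computed default.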


@ -1,2 +1,2 @@
from .transformer import SE3Transformer, SE3TransformerPooled from .transformer import SE3Transformer, SE3TransformerPooled
from .fiber import Fiber from .fiber import Fiber


@ -1,144 +1,144 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a # Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"), # copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation # to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, # the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the # and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions: # Software is furnished to do so, subject to the following conditions:
# #
# The above copyright notice and this permission notice shall be included in # The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software. # all copies or substantial portions of the Software.
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE. # DEALINGS IN THE SOFTWARE.
# #
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
from collections import namedtuple from collections import namedtuple
from itertools import product from itertools import product
from typing import Dict from typing import Dict
import torch import torch
from torch import Tensor from torch import Tensor
from se3_transformer.runtime.utils import degree_to_dim from se3_transformer.runtime.utils import degree_to_dim
FiberEl = namedtuple('FiberEl', ['degree', 'channels']) FiberEl = namedtuple('FiberEl', ['degree', 'channels'])
class Fiber(dict): class Fiber(dict):
""" """
Describes the structure of some set of features. Describes the structure of some set of features.
Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1. Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1.
Type-0 features: invariant scalars Type-0 features: invariant scalars
Type-1 features: equivariant 3D vectors Type-1 features: equivariant 3D vectors
Type-2 features: equivariant symmetric traceless matrices Type-2 features: equivariant symmetric traceless matrices
... ...
As inputs to a SE3 layer, there can be many features of the same types, and many features of different types. As inputs to a SE3 layer, there can be many features of the same types, and many features of different types.
The 'multiplicity' or 'number of channels' is the number of features of a given type. The 'multiplicity' or 'number of channels' is the number of features of a given type.
This class puts together all the degrees and their multiplicities in order to describe This class puts together all the degrees and their multiplicities in order to describe
the inputs, outputs or hidden features of SE3 layers. the inputs, outputs or hidden features of SE3 layers.
""" """
def __init__(self, structure): def __init__(self, structure):
if isinstance(structure, dict): if isinstance(structure, dict):
structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])] structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])]
elif not isinstance(structure[0], FiberEl): elif not isinstance(structure[0], FiberEl):
structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1]))) structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1])))
self.structure = structure self.structure = structure
super().__init__({d: m for d, m in self.structure}) super().__init__({d: m for d, m in self.structure})
@property @property
def degrees(self): def degrees(self):
return sorted([t.degree for t in self.structure]) return sorted([t.degree for t in self.structure])
@property @property
def channels(self): def channels(self):
return [self[d] for d in self.degrees] return [self[d] for d in self.degrees]
@property @property
def num_features(self): def num_features(self):
""" Size of the resulting tensor if all features were concatenated together """ """ Size of the resulting tensor if all features were concatenated together """
return sum(t.channels * degree_to_dim(t.degree) for t in self.structure) return sum(t.channels * degree_to_dim(t.degree) for t in self.structure)
@staticmethod @staticmethod
def create(num_degrees: int, num_channels: int): def create(num_degrees: int, num_channels: int):
""" Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """ """ Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """
return Fiber([(degree, num_channels) for degree in range(num_degrees)]) return Fiber([(degree, num_channels) for degree in range(num_degrees)])
@staticmethod @staticmethod
def from_features(feats: Dict[str, Tensor]): def from_features(feats: Dict[str, Tensor]):
""" Infer the Fiber structure from a feature dict """ """ Infer the Fiber structure from a feature dict """
structure = {} structure = {}
for k, v in feats.items(): for k, v in feats.items():
degree = int(k) degree = int(k)
assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)' assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
assert v.shape[-1] == degree_to_dim(degree) assert v.shape[-1] == degree_to_dim(degree)
structure[degree] = v.shape[-2] structure[degree] = v.shape[-2]
return Fiber(structure) return Fiber(structure)
def __getitem__(self, degree: int): def __getitem__(self, degree: int):
""" fiber[degree] returns the multiplicity for this degree """ """ fiber[degree] returns the multiplicity for this degree """
return dict(self.structure).get(degree, 0) return dict(self.structure).get(degree, 0)
def __iter__(self): def __iter__(self):
""" Iterate over namedtuples (degree, channels) """ """ Iterate over namedtuples (degree, channels) """
return iter(self.structure) return iter(self.structure)
def __mul__(self, other): def __mul__(self, other):
""" """
If other is an int, multiplies all the multiplicities by other. If other is an int, multiplies all the multiplicities by other.
If other is a fiber, returns the cartesian product. If other is a fiber, returns the cartesian product.
""" """
if isinstance(other, Fiber): if isinstance(other, Fiber):
return product(self.structure, other.structure) return product(self.structure, other.structure)
elif isinstance(other, int): elif isinstance(other, int):
return Fiber({t.degree: t.channels * other for t in self.structure}) return Fiber({t.degree: t.channels * other for t in self.structure})
def __add__(self, other): def __add__(self, other):
""" """
If other is an int, add other to all the multiplicities. If other is an int, add other to all the multiplicities.
If other is a fiber, add the multiplicities of the fibers together. If other is a fiber, add the multiplicities of the fibers together.
""" """
if isinstance(other, Fiber): if isinstance(other, Fiber):
return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure}) return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure})
elif isinstance(other, int): elif isinstance(other, int):
return Fiber({t.degree: t.channels + other for t in self.structure}) return Fiber({t.degree: t.channels + other for t in self.structure})
def __repr__(self): def __repr__(self):
return str(self.structure) return str(self.structure)
@staticmethod @staticmethod
def combine_max(f1, f2): def combine_max(f1, f2):
""" Combine two fibers by taking the maximum multiplicity for each degree in both fibers """ """ Combine two fibers by taking the maximum multiplicity for each degree in both fibers """
new_dict = dict(f1.structure) new_dict = dict(f1.structure)
for k, m in f2.structure: for k, m in f2.structure:
new_dict[k] = max(new_dict.get(k, 0), m) new_dict[k] = max(new_dict.get(k, 0), m)
return Fiber(list(new_dict.items())) return Fiber(list(new_dict.items()))
@staticmethod @staticmethod
def combine_selectively(f1, f2): def combine_selectively(f1, f2):
""" Combine two fibers by taking the sum of multiplicities for each degree in the first fiber """ """ Combine two fibers by taking the sum of multiplicities for each degree in the first fiber """
# only use orders which occur in fiber f1 # only use orders which occur in fiber f1
new_dict = dict(f1.structure) new_dict = dict(f1.structure)
for k in f1.degrees: for k in f1.degrees:
if k in f2.degrees: if k in f2.degrees:
new_dict[k] += f2[k] new_dict[k] += f2[k]
return Fiber(list(new_dict.items())) return Fiber(list(new_dict.items()))
def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int): def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int):
# dict(N, num_channels, 2d+1) -> (N, num_heads, -1) # dict(N, num_channels, 2d+1) -> (N, num_heads, -1)
fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in
self.degrees] self.degrees]
fibers = torch.cat(fibers, -1) fibers = torch.cat(fibers, -1)
return fibers return fibers
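
To make the `Fiber` bookkeeping above concrete, here is a small sketch that works on plain dicts instead of the real class. It assumes `degree_to_dim(d)` returns `2 * d + 1`, as the docstring states for type-k features. The function names `create_fiber`, `num_features` and `add_fibers` are illustrative stand-ins for `Fiber.create`, `Fiber.num_features` and `Fiber.__add__`, not the repository API.

```python
from typing import Dict


def degree_to_dim(degree: int) -> int:
    """Dimension of a type-`degree` feature (assumed to be 2 * degree + 1)."""
    return 2 * degree + 1


def create_fiber(num_degrees: int, num_channels: int) -> Dict[int, int]:
    """Degrees 0..num_degrees-1, all with the same multiplicity."""
    return {degree: num_channels for degree in range(num_degrees)}


def num_features(fiber: Dict[int, int]) -> int:
    """Size of the tensor obtained by concatenating all features."""
    return sum(channels * degree_to_dim(degree) for degree, channels in fiber.items())


def add_fibers(f1: Dict[int, int], f2: Dict[int, int]) -> Dict[int, int]:
    """Add multiplicities per degree, keeping only the degrees present in f1."""
    return {degree: channels + f2.get(degree, 0) for degree, channels in f1.items()}


hidden = create_fiber(num_degrees=4, num_channels=32)
print(num_features(hidden))  # 32 * (1 + 3 + 5 + 7) = 512
print(add_fibers({0: 16, 1: 16}, {0: 16}))  # {0: 32, 1: 16}
```

The second call mirrors how `value_fiber + key_query_fiber` is sized in the attention block: degrees missing from the second fiber contribute 0 channels.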


@ -1,5 +1,5 @@
from .linear import LinearSE3 from .linear import LinearSE3
from .norm import NormSE3 from .norm import NormSE3
from .pooling import GPooling from .pooling import GPooling
from .convolution import ConvSE3 from .convolution import ConvSE3
from .attention import AttentionBlockSE3 from .attention import AttentionBlockSE3


@ -1,180 +1,181 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a # Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"), # copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation # to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, # the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the # and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions: # Software is furnished to do so, subject to the following conditions:
# #
# The above copyright notice and this permission notice shall be included in # The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software. # all copies or substantial portions of the Software.
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE. # DEALINGS IN THE SOFTWARE.
# #
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import dgl import dgl
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from dgl import DGLGraph from dgl import DGLGraph
from dgl.ops import edge_softmax from dgl.ops import edge_softmax
from torch import Tensor from torch import Tensor
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from se3_transformer.model.fiber import Fiber from se3_transformer.model.fiber import Fiber
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.linear import LinearSE3 from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
from torch.cuda.nvtx import range as nvtx_range from torch.cuda.nvtx import range as nvtx_range
class AttentionSE3(nn.Module): class AttentionSE3(nn.Module):
""" Multi-headed sparse graph self-attention (SE(3)-equivariant) """ """ Multi-headed sparse graph self-attention (SE(3)-equivariant) """
def __init__( def __init__(
self, self,
num_heads: int, num_heads: int,
key_fiber: Fiber, key_fiber: Fiber,
value_fiber: Fiber value_fiber: Fiber
): ):
""" """
:param num_heads: Number of attention heads :param num_heads: Number of attention heads
:param key_fiber: Fiber for the keys (and also for the queries) :param key_fiber: Fiber for the keys (and also for the queries)
:param value_fiber: Fiber for the values :param value_fiber: Fiber for the values
""" """
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
self.key_fiber = key_fiber self.key_fiber = key_fiber
self.value_fiber = value_fiber self.value_fiber = value_fiber
def forward( def forward(
self, self,
value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused) value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused) key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
query: Dict[str, Tensor], # node features query: Dict[str, Tensor], # node features
graph: DGLGraph graph: DGLGraph
): ):
with nvtx_range('AttentionSE3'): with nvtx_range('AttentionSE3'):
with nvtx_range('reshape keys and queries'): with nvtx_range('reshape keys and queries'):
if isinstance(key, Tensor): if isinstance(key, Tensor):
# case where features of all types are fused # case where features of all types are fused
key = key.reshape(key.shape[0], self.num_heads, -1) key = key.reshape(key.shape[0], self.num_heads, -1)
# need to reshape queries that way to keep the same layout as keys # need to reshape queries that way to keep the same layout as keys
out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1) out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1) query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
else: else:
# features are not fused, need to fuse and reshape them # features are not fused, need to fuse and reshape them
key = self.key_fiber.to_attention_heads(key, self.num_heads) key = self.key_fiber.to_attention_heads(key, self.num_heads)
query = self.key_fiber.to_attention_heads(query, self.num_heads) query = self.key_fiber.to_attention_heads(query, self.num_heads)
with nvtx_range('attention dot product + softmax'): with nvtx_range('attention dot product + softmax'):
# Compute attention weights (softmax of inner product between key and query) # Compute attention weights (softmax of inner product between key and query)
edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1) edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
edge_weights = edge_weights / np.sqrt(self.key_fiber.num_features) edge_weights = edge_weights / np.sqrt(self.key_fiber.num_features)
edge_weights = edge_softmax(graph, edge_weights) edge_weights = edge_softmax(graph, edge_weights)
edge_weights = edge_weights[..., None, None] edge_weights = edge_weights[..., None, None]
with nvtx_range('weighted sum'): with nvtx_range('weighted sum'):
if isinstance(value, Tensor): if isinstance(value, Tensor):
# features of all types are fused # features of all types are fused
v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1]) v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
weights = edge_weights * v weights = edge_weights * v
feat_out = dgl.ops.copy_e_sum(graph, weights) feat_out = dgl.ops.copy_e_sum(graph, weights)
feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads
out = unfuse_features(feat_out, self.value_fiber.degrees) out = unfuse_features(feat_out, self.value_fiber.degrees)
else: else:
out = {} out = {}
for degree, channels in self.value_fiber: for degree, channels in self.value_fiber:
v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads, v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
degree_to_dim(degree)) degree_to_dim(degree))
weights = edge_weights * v weights = edge_weights * v
res = dgl.ops.copy_e_sum(graph, weights) res = dgl.ops.copy_e_sum(graph, weights)
out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads
return out return out
class AttentionBlockSE3(nn.Module): class AttentionBlockSE3(nn.Module):
""" Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """ """ Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """
def __init__( def __init__(
self, self,
fiber_in: Fiber, fiber_in: Fiber,
fiber_out: Fiber, fiber_out: Fiber,
fiber_edge: Optional[Fiber] = None, fiber_edge: Optional[Fiber] = None,
num_heads: int = 4, num_heads: int = 4,
channels_div: int = 2, channels_div: int = 2,
use_layer_norm: bool = False, use_layer_norm: bool = False,
max_degree: int = 4, max_degree: int = 4,
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL, fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
**kwargs low_memory: bool = False,
): **kwargs
""" ):
:param fiber_in: Fiber describing the input features """
:param fiber_out: Fiber describing the output features :param fiber_in: Fiber describing the input features
:param fiber_edge: Fiber describing the edge features (node distances excluded) :param fiber_out: Fiber describing the output features
:param num_heads: Number of attention heads :param fiber_edge: Fiber describing the edge features (node distances excluded)
:param channels_div: Divide the channels by this integer for computing values :param num_heads: Number of attention heads
:param use_layer_norm: Apply layer normalization between MLP layers :param channels_div: Divide the channels by this integer for computing values
:param max_degree: Maximum degree used in the bases computation :param use_layer_norm: Apply layer normalization between MLP layers
:param fuse_level: Maximum fuse level to use in TFN convolutions :param max_degree: Maximum degree used in the bases computation
""" :param fuse_level: Maximum fuse level to use in TFN convolutions
super().__init__() """
if fiber_edge is None: super().__init__()
fiber_edge = Fiber({}) if fiber_edge is None:
self.fiber_in = fiber_in fiber_edge = Fiber({})
# value_fiber has same structure as fiber_out but #channels divided by 'channels_div' self.fiber_in = fiber_in
value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out]) # value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
# key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out])
# (queries are merely projected, hence degrees have to match input) # key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber
key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees]) # (queries are merely projected, hence degrees have to match input)
key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])
self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level, self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
allow_fused_output=True) use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
self.to_query = LinearSE3(fiber_in, key_query_fiber) allow_fused_output=True, low_memory=low_memory)
self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber) self.to_query = LinearSE3(fiber_in, key_query_fiber)
self.project = LinearSE3(value_fiber + fiber_in, fiber_out) self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
self.project = LinearSE3(value_fiber + fiber_in, fiber_out)
def forward(
self, def forward(
node_features: Dict[str, Tensor], self,
edge_features: Dict[str, Tensor], node_features: Dict[str, Tensor],
graph: DGLGraph, edge_features: Dict[str, Tensor],
basis: Dict[str, Tensor] graph: DGLGraph,
): basis: Dict[str, Tensor]
with nvtx_range('AttentionBlockSE3'): ):
with nvtx_range('keys / values'): with nvtx_range('AttentionBlockSE3'):
fused_key_value = self.to_key_value(node_features, edge_features, graph, basis) with nvtx_range('keys / values'):
key, value = self._get_key_value_from_fused(fused_key_value) fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
key, value = self._get_key_value_from_fused(fused_key_value)
with nvtx_range('queries'):
query = self.to_query(node_features) with nvtx_range('queries'):
query = self.to_query(node_features)
z = self.attention(value, key, query, graph)
z_concat = aggregate_residual(node_features, z, 'cat') z = self.attention(value, key, query, graph)
return self.project(z_concat) z_concat = aggregate_residual(node_features, z, 'cat')
return self.project(z_concat)
def _get_key_value_from_fused(self, fused_key_value):
# Extract keys and values features from fused features def _get_key_value_from_fused(self, fused_key_value):
if isinstance(fused_key_value, Tensor): # Extract keys and values features from fused features
# Previous layer was a fully fused convolution if isinstance(fused_key_value, Tensor):
value, key = torch.chunk(fused_key_value, chunks=2, dim=-2) # Previous layer was a fully fused convolution
else: value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
key, value = {}, {} else:
for degree, feat in fused_key_value.items(): key, value = {}, {}
if int(degree) in self.fiber_in.degrees: for degree, feat in fused_key_value.items():
value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2) if int(degree) in self.fiber_in.degrees:
else: value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
value[degree] = feat else:
value[degree] = feat
return key, value
return key, value


@ -1,345 +1,354 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a # Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"), # copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation # to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, # the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the # and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions: # Software is furnished to do so, subject to the following conditions:
# #
# The above copyright notice and this permission notice shall be included in # The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software. # all copies or substantial portions of the Software.
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE. # DEALINGS IN THE SOFTWARE.
# #
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
from enum import Enum from enum import Enum
from itertools import product from itertools import product
from typing import Dict from typing import Dict
import dgl import dgl
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from dgl import DGLGraph import torch.utils.checkpoint
from torch import Tensor from dgl import DGLGraph
from torch.cuda.nvtx import range as nvtx_range from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime.utils import degree_to_dim, unfuse_features from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime.utils import degree_to_dim, unfuse_features
class ConvSE3FuseLevel(Enum):
""" class ConvSE3FuseLevel(Enum):
Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met. """
If a desired level L is picked and the level L cannot be applied to a layer, other fused ops < L are considered. Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met.
A higher level means faster training, but also more memory usage. If a desired level L is picked and the level L cannot be applied to a layer, other fused ops < L are considered.
If you are tight on memory and want to feed large inputs to the network, choose a low value. A higher level means faster training, but also more memory usage.
If you want to train fast, choose a high value. If you are tight on memory and want to feed large inputs to the network, choose a low value.
Recommended value is FULL with AMP. If you want to train fast, choose a high value.
Recommended value is FULL with AMP.
Fully fused TFN convolutions requirements:
- all input channels are the same Fully fused TFN convolutions requirements:
- all output channels are the same - all input channels are the same
- input degrees span the range [0, ..., max_degree] - all output channels are the same
- output degrees span the range [0, ..., max_degree] - input degrees span the range [0, ..., max_degree]
- output degrees span the range [0, ..., max_degree]
Partially fused TFN convolutions requirements:
* For fusing by output degree: Partially fused TFN convolutions requirements:
- all input channels are the same * For fusing by output degree:
- input degrees span the range [0, ..., max_degree] - all input channels are the same
* For fusing by input degree: - input degrees span the range [0, ..., max_degree]
- all output channels are the same * For fusing by input degree:
- output degrees span the range [0, ..., max_degree] - all output channels are the same
- output degrees span the range [0, ..., max_degree]
Original TFN pairwise convolutions: no requirements
""" Original TFN pairwise convolutions: no requirements
"""
FULL = 2
PARTIAL = 1 FULL = 2
NONE = 0 PARTIAL = 1
NONE = 0
class RadialProfile(nn.Module):
""" class RadialProfile(nn.Module):
Radial profile function. """
Outputs weights used to weigh basis matrices in order to get convolution kernels. Radial profile function.
In TFN notation: $R^{l,k}$ Outputs weights used to weigh basis matrices in order to get convolution kernels.
In SE(3)-Transformer notation: $\phi^{l,k}$ In TFN notation: $R^{l,k}$
In SE(3)-Transformer notation: $\phi^{l,k}$
Note:
In the original papers, this function only depends on relative node distances ||x||. Note:
Here, we allow this function to also take as input additional invariant edge features. In the original papers, this function only depends on relative node distances ||x||.
This does not break equivariance and adds expressive power to the model. Here, we allow this function to also take as input additional invariant edge features.
This does not break equivariance and adds expressive power to the model.
Diagram:
invariant edge features (node distances included) > MLP layer (shared across edges) > radial weights Diagram:
""" invariant edge features (node distances included) > MLP layer (shared across edges) > radial weights
"""
def __init__(
self, def __init__(
num_freq: int, self,
channels_in: int, num_freq: int,
channels_out: int, channels_in: int,
edge_dim: int = 1, channels_out: int,
mid_dim: int = 32, edge_dim: int = 1,
use_layer_norm: bool = False mid_dim: int = 32,
): use_layer_norm: bool = False
""" ):
:param num_freq: Number of frequencies """
:param channels_in: Number of input channels :param num_freq: Number of frequencies
:param channels_out: Number of output channels :param channels_in: Number of input channels
:param edge_dim: Number of invariant edge features (input to the radial function) :param channels_out: Number of output channels
:param mid_dim: Size of the hidden MLP layers :param edge_dim: Number of invariant edge features (input to the radial function)
:param use_layer_norm: Apply layer normalization between MLP layers :param mid_dim: Size of the hidden MLP layers
""" :param use_layer_norm: Apply layer normalization between MLP layers
super().__init__() """
modules = [ super().__init__()
nn.Linear(edge_dim, mid_dim), modules = [
nn.LayerNorm(mid_dim) if use_layer_norm else None, nn.Linear(edge_dim, mid_dim),
nn.ReLU(), nn.LayerNorm(mid_dim) if use_layer_norm else None,
nn.Linear(mid_dim, mid_dim), nn.ReLU(),
nn.LayerNorm(mid_dim) if use_layer_norm else None, nn.Linear(mid_dim, mid_dim),
nn.ReLU(), nn.LayerNorm(mid_dim) if use_layer_norm else None,
nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False) nn.ReLU(),
] nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
]
self.net = nn.Sequential(*[m for m in modules if m is not None])
self.net = nn.Sequential(*[m for m in modules if m is not None])
def forward(self, features: Tensor) -> Tensor:
return self.net(features) def forward(self, features: Tensor) -> Tensor:
return self.net(features)
class VersatileConvSE3(nn.Module):
""" class VersatileConvSE3(nn.Module):
Building block for TFN convolutions. """
This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions. Building block for TFN convolutions.
""" This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions.
"""
def __init__(self,
freq_sum: int, def __init__(self,
channels_in: int, freq_sum: int,
channels_out: int, channels_in: int,
edge_dim: int, channels_out: int,
use_layer_norm: bool, edge_dim: int,
fuse_level: ConvSE3FuseLevel): use_layer_norm: bool,
super().__init__() fuse_level: ConvSE3FuseLevel):
self.freq_sum = freq_sum super().__init__()
self.channels_out = channels_out self.freq_sum = freq_sum
self.channels_in = channels_in self.channels_out = channels_out
self.fuse_level = fuse_level self.channels_in = channels_in
self.radial_func = RadialProfile(num_freq=freq_sum, self.fuse_level = fuse_level
channels_in=channels_in, self.radial_func = RadialProfile(num_freq=freq_sum,
channels_out=channels_out, channels_in=channels_in,
edge_dim=edge_dim, channels_out=channels_out,
use_layer_norm=use_layer_norm) edge_dim=edge_dim,
use_layer_norm=use_layer_norm)
def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
with nvtx_range(f'VersatileConvSE3'): def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
num_edges = features.shape[0] with nvtx_range(f'VersatileConvSE3'):
in_dim = features.shape[2] num_edges = features.shape[0]
with nvtx_range(f'RadialProfile'): in_dim = features.shape[2]
radial_weights = self.radial_func(invariant_edge_feats) \ with nvtx_range(f'RadialProfile'):
.view(-1, self.channels_out, self.channels_in * self.freq_sum) radial_weights = self.radial_func(invariant_edge_feats) \
.view(-1, self.channels_out, self.channels_in * self.freq_sum)
if basis is not None:
# This block performs the einsum n i l, n o i f, n l f k -> n o k if basis is not None:
basis_view = basis.view(num_edges, in_dim, -1) # This block performs the einsum n i l, n o i f, n l f k -> n o k
tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1]) basis_view = basis.view(num_edges, in_dim, -1)
return radial_weights @ tmp tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
else: return radial_weights @ tmp
# k = l = 0 non-fused case else:
return radial_weights @ features # k = l = 0 non-fused case
return radial_weights @ features
class ConvSE3(nn.Module):
""" class ConvSE3(nn.Module):
SE(3)-equivariant graph convolution (Tensor Field Network convolution). """
This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance. SE(3)-equivariant graph convolution (Tensor Field Network convolution).
Features of different degrees interact together to produce output features. This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance.
Features of different degrees interact together to produce output features.
Note 1:
The option is given to not pool the output. This means that the convolution sum over neighbors will not be Note 1:
done, and the returned features will be edge features instead of node features. The option is given to not pool the output. This means that the convolution sum over neighbors will not be
done, and the returned features will be edge features instead of node features.
Note 2:
Unlike the original paper and implementation, this convolution can handle edge features of degree greater than 0. Note 2:
Input edge features are concatenated with input source node features before the kernel is applied. Unlike the original paper and implementation, this convolution can handle edge features of degree greater than 0.
""" Input edge features are concatenated with input source node features before the kernel is applied.
"""
def __init__(
self, def __init__(
fiber_in: Fiber, self,
fiber_out: Fiber, fiber_in: Fiber,
fiber_edge: Fiber, fiber_out: Fiber,
pool: bool = True, fiber_edge: Fiber,
use_layer_norm: bool = False, pool: bool = True,
self_interaction: bool = False, use_layer_norm: bool = False,
max_degree: int = 4, self_interaction: bool = False,
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL, max_degree: int = 4,
allow_fused_output: bool = False fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
): allow_fused_output: bool = False,
""" low_memory: bool = False
:param fiber_in: Fiber describing the input features ):
:param fiber_out: Fiber describing the output features """
:param fiber_edge: Fiber describing the edge features (node distances excluded) :param fiber_in: Fiber describing the input features
:param pool: If True, compute final node features by summing incoming edge features :param fiber_edge: Fiber describing the edge features (node distances excluded)
:param use_layer_norm: Apply layer normalization between MLP layers :param pool: If True, compute final node features by summing incoming edge features
:param self_interaction: Apply self-interaction of nodes :param pool: If True, compute final node features by averaging incoming edge features
:param max_degree: Maximum degree used in the bases computation :param use_layer_norm: Apply layer normalization between MLP layers
:param fuse_level: Maximum fuse level to use in TFN convolutions :param self_interaction: Apply self-interaction of nodes
:param allow_fused_output: Allow the module to output a fused representation of features :param max_degree: Maximum degree used in the bases computation
""" :param fuse_level: Maximum fuse level to use in TFN convolutions
super().__init__() :param allow_fused_output: Allow the module to output a fused representation of features
self.pool = pool """
self.fiber_in = fiber_in super().__init__()
self.fiber_out = fiber_out self.pool = pool
self.self_interaction = self_interaction self.fiber_in = fiber_in
self.max_degree = max_degree self.fiber_out = fiber_out
self.allow_fused_output = allow_fused_output self.self_interaction = self_interaction
self.max_degree = max_degree
# channels_in: account for the concatenation of edge features self.allow_fused_output = allow_fused_output
channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in]) self.conv_checkpoint = torch.utils.checkpoint.checkpoint if low_memory else lambda m, *x: m(*x)
channels_out_set = set([f.channels for f in self.fiber_out])
unique_channels_in = (len(channels_in_set) == 1) # channels_in: account for the concatenation of edge features
unique_channels_out = (len(channels_out_set) == 1) channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
degrees_up_to_max = list(range(max_degree + 1)) channels_out_set = set([f.channels for f in self.fiber_out])
common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm) unique_channels_in = (len(channels_in_set) == 1)
unique_channels_out = (len(channels_out_set) == 1)
if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \ degrees_up_to_max = list(range(max_degree + 1))
unique_channels_in and fiber_in.degrees == degrees_up_to_max and \ common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm)
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
# Single fused convolution if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
self.used_fuse_level = ConvSE3FuseLevel.FULL unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
sum_freq = sum([ # Single fused convolution
degree_to_dim(min(d_in, d_out)) self.used_fuse_level = ConvSE3FuseLevel.FULL
for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
]) sum_freq = sum([
degree_to_dim(min(d_in, d_out))
self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0], for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
fuse_level=self.used_fuse_level, **common_args) ])
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \ self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0],
unique_channels_in and fiber_in.degrees == degrees_up_to_max: fuse_level=self.used_fuse_level, **common_args)
# Convolutions fused per output degree
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
self.conv_out = nn.ModuleDict() unique_channels_in and fiber_in.degrees == degrees_up_to_max:
for d_out, c_out in fiber_out: # Convolutions fused per output degree
sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees]) self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out, self.conv_out = nn.ModuleDict()
fuse_level=self.used_fuse_level, **common_args) for d_out, c_out in fiber_out:
sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \ self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out,
unique_channels_out and fiber_out.degrees == degrees_up_to_max: fuse_level=self.used_fuse_level, **common_args)
# Convolutions fused per input degree
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
self.conv_in = nn.ModuleDict() unique_channels_out and fiber_out.degrees == degrees_up_to_max:
for d_in, c_in in fiber_in: # Convolutions fused per input degree
sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees]) self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
channels_in_new = c_in + fiber_edge[d_in] * (d_in > 0) self.conv_in = nn.ModuleDict()
self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, channels_in_new, list(channels_out_set)[0], for d_in, c_in in fiber_in:
fuse_level=self.used_fuse_level, **common_args) channels_in_new = c_in + fiber_edge[d_in] * (d_in > 0)
else: sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
# Use pairwise TFN convolutions self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, channels_in_new, list(channels_out_set)[0],
self.used_fuse_level = ConvSE3FuseLevel.NONE fuse_level=self.used_fuse_level, **common_args)
self.conv = nn.ModuleDict() else:
for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out): # Use pairwise TFN convolutions
dict_key = f'{degree_in},{degree_out}' self.used_fuse_level = ConvSE3FuseLevel.NONE
channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0) self.conv = nn.ModuleDict()
sum_freq = degree_to_dim(min(degree_in, degree_out)) for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out):
self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out, dict_key = f'{degree_in},{degree_out}'
fuse_level=self.used_fuse_level, **common_args) channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
sum_freq = degree_to_dim(min(degree_in, degree_out))
if self_interaction: self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out,
self.to_kernel_self = nn.ParameterDict() fuse_level=self.used_fuse_level, **common_args)
for degree_out, channels_out in fiber_out:
if fiber_in[degree_out]: if self_interaction:
self.to_kernel_self[str(degree_out)] = nn.Parameter( self.to_kernel_self = nn.ParameterDict()
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out])) for degree_out, channels_out in fiber_out:
if fiber_in[degree_out]:
def _try_unpad(self, feature, basis): self.to_kernel_self[str(degree_out)] = nn.Parameter(
# Account for padded basis torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
if basis is not None:
out_dim = basis.shape[-1] def _try_unpad(self, feature, basis):
out_dim += out_dim % 2 - 1 # Account for padded basis
return feature[..., :out_dim] if basis is not None:
else: out_dim = basis.shape[-1]
return feature out_dim += out_dim % 2 - 1
return feature[..., :out_dim]
def forward( else:
self, return feature
node_feats: Dict[str, Tensor],
edge_feats: Dict[str, Tensor], def forward(
graph: DGLGraph, self,
basis: Dict[str, Tensor] node_feats: Dict[str, Tensor],
): edge_feats: Dict[str, Tensor],
with nvtx_range(f'ConvSE3'): graph: DGLGraph,
invariant_edge_feats = edge_feats['0'].squeeze(-1) basis: Dict[str, Tensor]
src, dst = graph.edges() ):
out = {} with nvtx_range(f'ConvSE3'):
in_features = [] invariant_edge_feats = edge_feats['0'].squeeze(-1)
src, dst = graph.edges()
# Fetch all input features from edge and node features out = {}
for degree_in in self.fiber_in.degrees: in_features = []
src_node_features = node_feats[str(degree_in)][src]
if degree_in > 0 and str(degree_in) in edge_feats: # Fetch all input features from edge and node features
# Handle edge features of any type by concatenating them to node features for degree_in in self.fiber_in.degrees:
src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1) src_node_features = node_feats[str(degree_in)][src]
in_features.append(src_node_features) if degree_in > 0 and str(degree_in) in edge_feats:
# Handle edge features of any type by concatenating them to node features
if self.used_fuse_level == ConvSE3FuseLevel.FULL: src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
in_features_fused = torch.cat(in_features, dim=-1) in_features.append(src_node_features)
out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])
if self.used_fuse_level == ConvSE3FuseLevel.FULL:
if not self.allow_fused_output or self.self_interaction or self.pool: in_features_fused = torch.cat(in_features, dim=-1)
out = unfuse_features(out, self.fiber_out.degrees) out = self.conv_checkpoint(
self.conv, in_features_fused, invariant_edge_feats, basis['fully_fused']
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'): )
in_features_fused = torch.cat(in_features, dim=-1)
for degree_out in self.fiber_out.degrees: if not self.allow_fused_output or self.self_interaction or self.pool:
basis_used = basis[f'out{degree_out}_fused'] out = unfuse_features(out, self.fiber_out.degrees)
out[str(degree_out)] = self._try_unpad(
self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats, basis_used), elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
basis_used) in_features_fused = torch.cat(in_features, dim=-1)
for degree_out in self.fiber_out.degrees:
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'): basis_used = basis[f'out{degree_out}_fused']
out = 0 out[str(degree_out)] = self._try_unpad(
for degree_in, feature in zip(self.fiber_in.degrees, in_features): self.conv_checkpoint(
out = out + self.conv_in[str(degree_in)](feature, invariant_edge_feats, basis[f'in{degree_in}_fused']) self.conv_out[str(degree_out)], in_features_fused, invariant_edge_feats, basis_used
if not self.allow_fused_output or self.self_interaction or self.pool: ), basis_used)
out = unfuse_features(out, self.fiber_out.degrees)
else: elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
# Fallback to pairwise TFN convolutions out = 0
for degree_out in self.fiber_out.degrees: for degree_in, feature in zip(self.fiber_in.degrees, in_features):
out_feature = 0 out = out + self.conv_checkpoint(
for degree_in, feature in zip(self.fiber_in.degrees, in_features): self.conv_in[str(degree_in)], feature, invariant_edge_feats, basis[f'in{degree_in}_fused']
dict_key = f'{degree_in},{degree_out}' )
basis_used = basis.get(dict_key, None) if not self.allow_fused_output or self.self_interaction or self.pool:
out_feature = out_feature + self._try_unpad( out = unfuse_features(out, self.fiber_out.degrees)
self.conv[dict_key](feature, invariant_edge_feats, basis_used), else:
basis_used) # Fallback to pairwise TFN convolutions
out[str(degree_out)] = out_feature for degree_out in self.fiber_out.degrees:
out_feature = 0
for degree_out in self.fiber_out.degrees: for degree_in, feature in zip(self.fiber_in.degrees, in_features):
if self.self_interaction and str(degree_out) in self.to_kernel_self: dict_key = f'{degree_in},{degree_out}'
with nvtx_range(f'self interaction'): basis_used = basis.get(dict_key, None)
dst_features = node_feats[str(degree_out)][dst] out_feature = out_feature + self._try_unpad(
kernel_self = self.to_kernel_self[str(degree_out)] self.conv_checkpoint(
out[str(degree_out)] = out[str(degree_out)] + kernel_self @ dst_features self.conv[dict_key], feature, invariant_edge_feats, basis_used
), basis_used)
if self.pool: out[str(degree_out)] = out_feature
with nvtx_range(f'pooling'):
if isinstance(out, dict): for degree_out in self.fiber_out.degrees:
out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)]) if self.self_interaction and str(degree_out) in self.to_kernel_self:
else: with nvtx_range(f'self interaction'):
out = dgl.ops.copy_e_sum(graph, out) dst_features = node_feats[str(degree_out)][dst]
return out kernel_self = self.to_kernel_self[str(degree_out)]
out[str(degree_out)] = out[str(degree_out)] + kernel_self @ dst_features
if self.pool:
with nvtx_range(f'pooling'):
if isinstance(out, dict):
out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
else:
out = dgl.ops.copy_e_sum(graph, out)
return out
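
The `low_memory` flag added in this file picks, once in the constructor, between `torch.utils.checkpoint.checkpoint` and a pass-through lambda, so every call site in `forward` keeps the same `self.conv_checkpoint(module, *inputs)` shape. The following is a minimal pure-Python sketch of that dispatch only. `recording_checkpoint` and `TinyConv` are hypothetical stand-ins (they are not part of PyTorch or of this repository) so the snippet runs without PyTorch installed, and no real activation recomputation happens here.

```python
from typing import Callable, List

calls: List[str] = []  # records which path was taken, for illustration only


def recording_checkpoint(fn: Callable, *args):
    """Hypothetical stand-in for torch.utils.checkpoint.checkpoint(fn, *args).

    The real function avoids storing intermediate activations and recomputes
    them during the backward pass; this stub only records the call and
    forwards it unchanged.
    """
    calls.append("checkpointed")
    return fn(*args)


class TinyConv:
    """Hypothetical wrapper mirroring how ConvSE3 stores its call strategy."""

    def __init__(self, low_memory: bool = False):
        # Same shape as the line added in the diff: choose the wrapper once.
        self.conv_checkpoint = recording_checkpoint if low_memory else lambda m, *x: m(*x)

    def forward(self, a: float, b: float) -> float:
        # The call site looks identical in both modes.
        return self.conv_checkpoint(lambda x, y: x * y + 1.0, a, b)


default_out = TinyConv(low_memory=False).forward(2.0, 3.0)
low_mem_out = TinyConv(low_memory=True).forward(2.0, 3.0)
```

Both modes return the same value; only the wrapper differs. With the real `torch.utils.checkpoint.checkpoint` in place of the stub, the expected trade-off is extra forward computation during the backward pass in exchange for lower peak memory.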

@ -1,59 +1,59 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a # Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"), # copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation # to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, # the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the # and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions: # Software is furnished to do so, subject to the following conditions:
# #
# The above copyright notice and this permission notice shall be included in # The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software. # all copies or substantial portions of the Software.
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE. # DEALINGS IN THE SOFTWARE.
# #
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
from typing import Dict from typing import Dict
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from se3_transformer.model.fiber import Fiber from se3_transformer.model.fiber import Fiber
class LinearSE3(nn.Module): class LinearSE3(nn.Module):
""" """
Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution. Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
Maps a fiber to a fiber with the same degrees (channels may be different). Maps a fiber to a fiber with the same degrees (channels may be different).
No interaction between degrees, but interaction between channels. No interaction between degrees, but interaction between channels.
type-0 features (C_0 channels) > Linear(bias=False) > type-0 features (C'_0 channels) type-0 features (C_0 channels) > Linear(bias=False) > type-0 features (C'_0 channels)
type-1 features (C_1 channels) > Linear(bias=False) > type-1 features (C'_1 channels) type-1 features (C_1 channels) > Linear(bias=False) > type-1 features (C'_1 channels)
: :
type-k features (C_k channels) > Linear(bias=False) > type-k features (C'_k channels) type-k features (C_k channels) > Linear(bias=False) > type-k features (C'_k channels)
""" """
def __init__(self, fiber_in: Fiber, fiber_out: Fiber): def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
super().__init__() super().__init__()
self.weights = nn.ParameterDict({ self.weights = nn.ParameterDict({
str(degree_out): nn.Parameter( str(degree_out): nn.Parameter(
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out])) torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
for degree_out, channels_out in fiber_out for degree_out, channels_out in fiber_out
}) })
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]: def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
return { return {
degree: self.weights[degree] @ features[degree] degree: self.weights[degree] @ features[degree]
for degree, weight in self.weights.items() for degree, weight in self.weights.items()
} }
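
The `LinearSE3` docstring states that each degree is mapped by its own bias-free weight matrix, mixing channels but never mixing degrees. A small sketch of that per-degree map, using nested Python lists in place of tensors so it runs without PyTorch, could look like the following. The helper names `matmul` and `linear_se3` and the example weights are assumptions made for illustration and do not appear in the repository.

```python
from typing import Dict, List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain nested-list matrix product, standing in for `@` on tensors."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def linear_se3(weights: Dict[str, Matrix], features: Dict[str, Matrix]) -> Dict[str, Matrix]:
    """Apply one weight matrix per degree; features of different degrees never mix."""
    return {degree: matmul(w, features[degree]) for degree, w in weights.items()}


# One node, two input channels per degree; a degree-d feature has 2d + 1 components.
features = {
    "0": [[1.0], [2.0]],                      # shape (2 channels, 1)
    "1": [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],  # shape (2 channels, 3)
}
# Hypothetical weights mapping 2 input channels to 1 output channel per degree.
weights = {
    "0": [[1.0, 1.0]],
    "1": [[2.0, 3.0]],
}

out = linear_se3(weights, features)
```

The weight matrix acts only on the channel axis and leaves the `2d + 1` components of each degree untouched, which is why a layer of this kind preserves equivariance.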

@ -1,83 +1,83 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a # Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"), # copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation # to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, # the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the # and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions: # Software is furnished to do so, subject to the following conditions:
# #
# The above copyright notice and this permission notice shall be included in # The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software. # all copies or substantial portions of the Software.
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE. # DEALINGS IN THE SOFTWARE.
# #
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
from typing import Dict from typing import Dict
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range from torch.cuda.nvtx import range as nvtx_range
from se3_transformer.model.fiber import Fiber from se3_transformer.model.fiber import Fiber
class NormSE3(nn.Module): class NormSE3(nn.Module):
""" """
Norm-based SE(3)-equivariant nonlinearity. Norm-based SE(3)-equivariant nonlinearity.
> feature_norm > LayerNorm() > ReLU() > feature_norm > LayerNorm() > ReLU()
feature_in * > feature_out feature_in * > feature_out
> feature_phase > feature_phase
""" """
NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16 NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()): def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
super().__init__() super().__init__()
self.fiber = fiber self.fiber = fiber
self.nonlinearity = nonlinearity self.nonlinearity = nonlinearity
if len(set(fiber.channels)) == 1: if len(set(fiber.channels)) == 1:
# Fuse all the layer normalizations into a group normalization # Fuse all the layer normalizations into a group normalization
self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels)) self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
else: else:
# Use multiple layer normalizations # Use multiple layer normalizations
self.layer_norms = nn.ModuleDict({ self.layer_norms = nn.ModuleDict({
str(degree): nn.LayerNorm(channels) str(degree): nn.LayerNorm(channels)
for degree, channels in fiber for degree, channels in fiber
}) })
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]: def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
with nvtx_range('NormSE3'): with nvtx_range('NormSE3'):
output = {} output = {}
if hasattr(self, 'group_norm'): if hasattr(self, 'group_norm'):
# Compute per-degree norms of features # Compute per-degree norms of features
norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP) norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
for d in self.fiber.degrees] for d in self.fiber.degrees]
fused_norms = torch.cat(norms, dim=-2) fused_norms = torch.cat(norms, dim=-2)
# Transform the norms only # Transform the norms only
new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1) new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2) new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)
# Scale features to the new norms # Scale features to the new norms
for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees): for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
output[str(d)] = features[str(d)] / norm * new_norm output[str(d)] = features[str(d)] / norm * new_norm
else: else:
for degree, feat in features.items(): for degree, feat in features.items():
norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP) norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1)) new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
output[degree] = new_norm * feat / norm output[degree] = new_norm * feat / norm
return output return output
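
The norm-based nonlinearity in `NormSE3` can be illustrated without PyTorch. The following is a minimal sketch, not the module itself: the learned LayerNorm/GroupNorm and configurable nonlinearity are replaced by a hypothetical "ReLU on a shifted norm" transform (`shift` is an assumed parameter), and a single type-1 vector stands in for a feature tensor. It shows that only the norm is transformed, so the output rotates together with the input.

```python
import math

NORM_CLAMP = 2 ** -24  # same clamp value as NormSE3.NORM_CLAMP in the source


def norm_nonlinearity(vec, shift=0.5):
    """Rescale `vec` to a new norm while keeping its direction.

    `shift` and the ReLU-on-shifted-norm transform are hypothetical stand-ins
    for the learned normalization + nonlinearity used in the source.
    """
    norm = max(math.sqrt(sum(x * x for x in vec)), NORM_CLAMP)
    new_norm = max(norm - shift, 0.0)  # ReLU applied to the rotation-invariant norm
    return [x / norm * new_norm for x in vec]


def rotate_z_90(vec):
    """Exact 90-degree rotation about the z axis: (x, y, z) -> (-y, x, z)."""
    x, y, z = vec
    return [-y, x, z]


v = [3.0, 4.0, 0.0]
out = norm_nonlinearity(v)  # norm goes from 5.0 to 4.5, direction unchanged

# Equivariance check: rotating then applying equals applying then rotating
rotated_then_applied = norm_nonlinearity(rotate_z_90(v))
applied_then_rotated = rotate_z_90(out)

# The clamp avoids a division by zero for a zero-norm feature
zero_out = norm_nonlinearity([0.0, 0.0, 0.0])
```

Because the norm of a feature does not change under rotation, any function of the norm alone keeps the layer equivariant; the clamp only guards the division.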


@ -1,53 +1,53 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a # Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"), # copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation # to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, # the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the # and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions: # Software is furnished to do so, subject to the following conditions:
# #
# The above copyright notice and this permission notice shall be included in # The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software. # all copies or substantial portions of the Software.
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE. # DEALINGS IN THE SOFTWARE.
# #
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
from typing import Dict, Literal from typing import Dict, Literal
import torch.nn as nn import torch.nn as nn
from dgl import DGLGraph from dgl import DGLGraph
from dgl.nn.pytorch import AvgPooling, MaxPooling from dgl.nn.pytorch import AvgPooling, MaxPooling
from torch import Tensor from torch import Tensor
class GPooling(nn.Module): class GPooling(nn.Module):
""" """
Graph max/average pooling on a given feature type. Graph max/average pooling on a given feature type.
The average can be taken for any feature type, and equivariance will be maintained. The average can be taken for any feature type, and equivariance will be maintained.
The maximum can only be taken for invariant features (type 0). The maximum can only be taken for invariant features (type 0).
If you want max-pooling for type > 0 features, look into Vector Neurons. If you want max-pooling for type > 0 features, look into Vector Neurons.
""" """
def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'): def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
""" """
:param feat_type: Feature type to pool :param feat_type: Feature type to pool
:param pool: Type of pooling: max or avg :param pool: Type of pooling: max or avg
""" """
super().__init__() super().__init__()
assert pool in ['max', 'avg'], f'Unknown pooling: {pool}' assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance' assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
self.feat_type = feat_type self.feat_type = feat_type
self.pool = MaxPooling() if pool == 'max' else AvgPooling() self.pool = MaxPooling() if pool == 'max' else AvgPooling()
def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor: def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
pooled = self.pool(graph, features[str(self.feat_type)]) pooled = self.pool(graph, features[str(self.feat_type)])
return pooled.squeeze(dim=-1) return pooled.squeeze(dim=-1)
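
The `GPooling` docstring states that averaging preserves equivariance for any feature type, while max pooling is only safe for type-0 features. A small sketch of why, using plain tuples as type-1 (vector) node features and an exact 90-degree rotation; the helper names are illustrative and are not part of DGL or this repository:

```python
def rotate_z_90(vec):
    """Exact 90-degree rotation about the z axis: (x, y, z) -> (-y, x, z)."""
    x, y, z = vec
    return (-y, x, z)


def avg_pool(vectors):
    """Component-wise mean over the nodes of one graph."""
    n = len(vectors)
    return tuple(sum(component) / n for component in zip(*vectors))


def max_pool(vectors):
    """Component-wise maximum over the nodes of one graph."""
    return tuple(max(component) for component in zip(*vectors))


nodes = [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
rotated_nodes = [rotate_z_90(v) for v in nodes]

# Averaging is linear, so it commutes with the rotation
avg_commutes = avg_pool(rotated_nodes) == rotate_z_90(avg_pool(nodes))

# The component-wise maximum depends on the coordinate frame, so it does not
max_commutes = max_pool(rotated_nodes) == rotate_z_90(max_pool(nodes))
```

This is the property the second `assert` in `GPooling.__init__` guards: it rejects `pool='max'` whenever `feat_type` is not 0.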


@ -1,222 +1,223 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a # Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"), # copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation # to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, # the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the # and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions: # Software is furnished to do so, subject to the following conditions:
# #
# The above copyright notice and this permission notice shall be included in # The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software. # all copies or substantial portions of the Software.
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE. # DEALINGS IN THE SOFTWARE.
# #
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import logging import logging
from typing import Optional, Literal, Dict from typing import Optional, Literal, Dict
import torch import torch
import torch.nn as nn import torch.nn as nn
from dgl import DGLGraph from dgl import DGLGraph
from torch import Tensor from torch import Tensor
from se3_transformer.model.basis import get_basis, update_basis_with_fused from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.layers.attention import AttentionBlockSE3 from se3_transformer.model.layers.attention import AttentionBlockSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.norm import NormSE3 from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling from se3_transformer.model.layers.pooling import GPooling
from se3_transformer.runtime.utils import str2bool from se3_transformer.runtime.utils import str2bool
from se3_transformer.model.fiber import Fiber from se3_transformer.model.fiber import Fiber
class Sequential(nn.Sequential): class Sequential(nn.Sequential):
""" Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """ """ Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """
def forward(self, input, *args, **kwargs): def forward(self, input, *args, **kwargs):
for module in self: for module in self:
input = module(input, *args, **kwargs) input = module(input, *args, **kwargs)
return input return input
def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None): def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
""" Add relative positions to existing edge features """ """ Add relative positions to existing edge features """
edge_features = edge_features.copy() if edge_features else {} edge_features = edge_features.copy() if edge_features else {}
r = relative_pos.norm(dim=-1, keepdim=True) r = relative_pos.norm(dim=-1, keepdim=True)
if '0' in edge_features: if '0' in edge_features:
edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1) edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
else: else:
edge_features['0'] = r[..., None] edge_features['0'] = r[..., None]
return edge_features return edge_features
class SE3Transformer(nn.Module): class SE3Transformer(nn.Module):
def __init__(self, def __init__(self,
num_layers: int, num_layers: int,
fiber_in: Fiber, fiber_in: Fiber,
fiber_hidden: Fiber, fiber_hidden: Fiber,
fiber_out: Fiber, fiber_out: Fiber,
num_heads: int, num_heads: int,
channels_div: int, channels_div: int,
fiber_edge: Fiber = Fiber({}), fiber_edge: Fiber = Fiber({}),
return_type: Optional[int] = None, return_type: Optional[int] = None,
pooling: Optional[Literal['avg', 'max']] = None, pooling: Optional[Literal['avg', 'max']] = None,
norm: bool = True, norm: bool = True,
use_layer_norm: bool = True, use_layer_norm: bool = True,
tensor_cores: bool = False, tensor_cores: bool = False,
low_memory: bool = False, low_memory: bool = False,
**kwargs): **kwargs):
""" """
:param num_layers: Number of attention layers :param num_layers: Number of attention layers
:param fiber_in: Input fiber description :param fiber_in: Input fiber description
:param fiber_hidden: Hidden fiber description :param fiber_hidden: Hidden fiber description
:param fiber_out: Output fiber description :param fiber_out: Output fiber description
:param fiber_edge: Input edge fiber description :param fiber_edge: Input edge fiber description
:param num_heads: Number of attention heads :param num_heads: Number of attention heads
:param channels_div: Channels division before feeding to attention layer :param channels_div: Channels division before feeding to attention layer
:param return_type: Return only features of this type :param return_type: Return only features of this type
:param pooling: 'avg' or 'max' graph pooling before MLP layers :param pooling: 'avg' or 'max' graph pooling before MLP layers
:param norm: Apply a normalization layer after each attention block :param norm: Apply a normalization layer after each attention block
:param use_layer_norm: Apply layer normalization between MLP layers :param use_layer_norm: Apply layer normalization between MLP layers
:param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases) :param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
:param low_memory: If True, will use slower ops that use less memory :param low_memory: If True, will use slower ops that use less memory
""" """
super().__init__() super().__init__()
self.num_layers = num_layers self.num_layers = num_layers
self.fiber_edge = fiber_edge self.fiber_edge = fiber_edge
self.num_heads = num_heads self.num_heads = num_heads
self.channels_div = channels_div self.channels_div = channels_div
self.return_type = return_type self.return_type = return_type
self.pooling = pooling self.pooling = pooling
self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees) self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
self.tensor_cores = tensor_cores self.tensor_cores = tensor_cores
self.low_memory = low_memory self.low_memory = low_memory
if low_memory and not tensor_cores: if low_memory:
logging.warning('Low memory mode will have no effect with no Tensor Cores') self.fuse_level = ConvSE3FuseLevel.NONE
else:
# Fully fused convolutions when using Tensor Cores (and not low memory mode) # Fully fused convolutions when using Tensor Cores (and not low memory mode)
fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL self.fuse_level = ConvSE3FuseLevel.FULL if tensor_cores else ConvSE3FuseLevel.PARTIAL
graph_modules = [] graph_modules = []
for i in range(num_layers): for i in range(num_layers):
graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in, graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
fiber_out=fiber_hidden, fiber_out=fiber_hidden,
fiber_edge=fiber_edge, fiber_edge=fiber_edge,
num_heads=num_heads, num_heads=num_heads,
channels_div=channels_div, channels_div=channels_div,
use_layer_norm=use_layer_norm, use_layer_norm=use_layer_norm,
max_degree=self.max_degree, max_degree=self.max_degree,
fuse_level=fuse_level)) fuse_level=self.fuse_level,
if norm: low_memory=low_memory))
graph_modules.append(NormSE3(fiber_hidden)) if norm:
fiber_in = fiber_hidden graph_modules.append(NormSE3(fiber_hidden))
fiber_in = fiber_hidden
graph_modules.append(ConvSE3(fiber_in=fiber_in,
fiber_out=fiber_out, graph_modules.append(ConvSE3(fiber_in=fiber_in,
fiber_edge=fiber_edge, fiber_out=fiber_out,
self_interaction=True, fiber_edge=fiber_edge,
use_layer_norm=use_layer_norm, self_interaction=True,
max_degree=self.max_degree)) use_layer_norm=use_layer_norm,
self.graph_modules = Sequential(*graph_modules) max_degree=self.max_degree))
self.graph_modules = Sequential(*graph_modules)
if pooling is not None:
assert return_type is not None, 'return_type must be specified when pooling' if pooling is not None:
self.pooling_module = GPooling(pool=pooling, feat_type=return_type) assert return_type is not None, 'return_type must be specified when pooling'
self.pooling_module = GPooling(pool=pooling, feat_type=return_type)
def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
edge_feats: Optional[Dict[str, Tensor]] = None, def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
basis: Optional[Dict[str, Tensor]] = None): edge_feats: Optional[Dict[str, Tensor]] = None,
# Compute bases in case they weren't precomputed as part of the data loading basis: Optional[Dict[str, Tensor]] = None):
basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False, # Compute bases in case they weren't precomputed as part of the data loading
use_pad_trick=self.tensor_cores and not self.low_memory, basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
amp=torch.is_autocast_enabled()) use_pad_trick=self.tensor_cores and not self.low_memory,
amp=torch.is_autocast_enabled())
# Add fused bases (per output degree, per input degree, and fully fused) to the dict
basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory, # Add fused bases (per output degree, per input degree, and fully fused) to the dict
fully_fused=self.tensor_cores and not self.low_memory) basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
fully_fused=self.fuse_level == ConvSE3FuseLevel.FULL)
edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)
edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)
node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)
node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)
if self.pooling is not None:
return self.pooling_module(node_feats, graph=graph) if self.pooling is not None:
return self.pooling_module(node_feats, graph=graph)
if self.return_type is not None:
return node_feats[str(self.return_type)] if self.return_type is not None:
return node_feats[str(self.return_type)]
return node_feats
return node_feats
@staticmethod
def add_argparse_args(parser): @staticmethod
parser.add_argument('--num_layers', type=int, default=7, def add_argparse_args(parser):
help='Number of stacked Transformer layers') parser.add_argument('--num_layers', type=int, default=7,
parser.add_argument('--num_heads', type=int, default=8, help='Number of stacked Transformer layers')
help='Number of heads in self-attention') parser.add_argument('--num_heads', type=int, default=8,
parser.add_argument('--channels_div', type=int, default=2, help='Number of heads in self-attention')
help='Channels division before feeding to attention layer') parser.add_argument('--channels_div', type=int, default=2,
parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'], help='Channels division before feeding to attention layer')
help='Type of graph pooling') parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False, help='Type of graph pooling')
help='Apply a normalization layer after each attention block') parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False, help='Apply a normalization layer after each attention block')
help='Apply layer normalization between MLP layers') parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False, help='Apply layer normalization between MLP layers')
help='If true, will use fused ops that are slower but that use less memory ' parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
'(expect 25 percent less memory). ' help='If true, will use fused ops that are slower but that use less memory '
'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs') '(expect 25 percent less memory). '
'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')
return parser
return parser
class SE3TransformerPooled(nn.Module):
def __init__(self, class SE3TransformerPooled(nn.Module):
fiber_in: Fiber, def __init__(self,
fiber_out: Fiber, fiber_in: Fiber,
fiber_edge: Fiber, fiber_out: Fiber,
num_degrees: int, fiber_edge: Fiber,
num_channels: int, num_degrees: int,
output_dim: int, num_channels: int,
**kwargs): output_dim: int,
super().__init__() **kwargs):
kwargs['pooling'] = kwargs['pooling'] or 'max' super().__init__()
self.transformer = SE3Transformer( kwargs['pooling'] = kwargs['pooling'] or 'max'
fiber_in=fiber_in, self.transformer = SE3Transformer(
fiber_hidden=Fiber.create(num_degrees, num_channels), fiber_in=fiber_in,
fiber_out=fiber_out, fiber_hidden=Fiber.create(num_degrees, num_channels),
fiber_edge=fiber_edge, fiber_out=fiber_out,
return_type=0, fiber_edge=fiber_edge,
**kwargs return_type=0,
) **kwargs
)
n_out_features = fiber_out.num_features
self.mlp = nn.Sequential( n_out_features = fiber_out.num_features
nn.Linear(n_out_features, n_out_features), self.mlp = nn.Sequential(
nn.ReLU(), nn.Linear(n_out_features, n_out_features),
nn.Linear(n_out_features, output_dim) nn.ReLU(),
) nn.Linear(n_out_features, output_dim)
)
def forward(self, graph, node_feats, edge_feats, basis=None):
feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1) def forward(self, graph, node_feats, edge_feats, basis=None):
y = self.mlp(feats).squeeze(-1) feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
return y y = self.mlp(feats).squeeze(-1)
return y
@staticmethod
def add_argparse_args(parent_parser): @staticmethod
parser = parent_parser.add_argument_group("Model architecture") def add_argparse_args(parent_parser):
SE3Transformer.add_argparse_args(parser) parser = parent_parser.add_argument_group("Model architecture")
parser.add_argument('--num_degrees', SE3Transformer.add_argparse_args(parser)
help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]', parser.add_argument('--num_degrees',
type=int, default=4) help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32) type=int, default=4)
return parent_parser parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
return parent_parser
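
The changed selection logic in `SE3Transformer.__init__` and `forward` can be summarized in a standalone sketch. `FuseLevel` below is a hypothetical stand-in for `ConvSE3FuseLevel` (its numeric values are assumptions), and `select_fuse_level` and `basis_options` are illustrative helpers that do not exist in the repository; they only mirror the conditions visible in the new side of the diff.

```python
from enum import Enum


class FuseLevel(Enum):
    """Hypothetical stand-in for ConvSE3FuseLevel; values are assumed."""
    NONE = 0
    PARTIAL = 1
    FULL = 2


def select_fuse_level(tensor_cores: bool, low_memory: bool) -> FuseLevel:
    """Mirror the new constructor logic: low memory mode disables fusion."""
    if low_memory:
        return FuseLevel.NONE
    # Fully fused convolutions only when Tensor Cores are used
    return FuseLevel.FULL if tensor_cores else FuseLevel.PARTIAL


def basis_options(tensor_cores: bool, low_memory: bool) -> dict:
    """Mirror the flags passed to the basis helpers in forward()."""
    level = select_fuse_level(tensor_cores, low_memory)
    return {
        'use_pad_trick': tensor_cores and not low_memory,
        'fully_fused': level == FuseLevel.FULL,
    }


for tc in (False, True):
    for lm in (False, True):
        print(f'tensor_cores={tc!s:5} low_memory={lm!s:5} -> {select_fuse_level(tc, lm).name}')
```

Compared with the old side of the diff, the visible difference is that `low_memory=True` now selects the unfused level regardless of `tensor_cores`, instead of falling back to partial fusion and logging a warning when Tensor Cores are absent.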


@ -1,70 +1,72 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a # Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"), # copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation # to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, # the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the # and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions: # Software is furnished to do so, subject to the following conditions:
# #
# The above copyright notice and this permission notice shall be included in # The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software. # all copies or substantial portions of the Software.
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE. # DEALINGS IN THE SOFTWARE.
# #
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import argparse import argparse
import pathlib import pathlib
from se3_transformer.data_loading import QM9DataModule from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled from se3_transformer.model import SE3TransformerPooled
from se3_transformer.runtime.utils import str2bool from se3_transformer.runtime.utils import str2bool
PARSER = argparse.ArgumentParser(description='SE(3)-Transformer') PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')
paths = PARSER.add_argument_group('Paths') paths = PARSER.add_argument_group('Paths')
paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'), paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'),
help='Directory where the data is located or should be downloaded') help='Directory where the data is located or should be downloaded')
paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'), paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'),
help='Directory where the results logs should be saved') help='Directory where the results logs should be saved')
paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json', paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json',
help='Name for the resulting DLLogger JSON file') help='Name for the resulting DLLogger JSON file')
paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None, paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None,
help='File where the checkpoint should be saved') help='File where the checkpoint should be saved')
paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None, paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None,
help='File of the checkpoint to be loaded') help='File of the checkpoint to be loaded')
optimizer = PARSER.add_argument_group('Optimizer') optimizer = PARSER.add_argument_group('Optimizer')
optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam') optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam')
optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002) optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002)
optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None) optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None)
optimizer.add_argument('--momentum', type=float, default=0.9) optimizer.add_argument('--momentum', type=float, default=0.9)
optimizer.add_argument('--weight_decay', type=float, default=0.1) optimizer.add_argument('--weight_decay', type=float, default=0.1)
PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs') PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size') PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size')
PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally') PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally')
PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers') PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers')
PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision') PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision')
PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms') PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation') PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs') PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=20, PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=20,
help='Do an evaluation round every N epochs') help='Do an evaluation round every N epochs')
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False, PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
help='Minimize stdout output') help='Minimize stdout output')
PARSER.add_argument('--wandb', type=str2bool, nargs='?', const=True, default=False,
PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False, help='Enable W&B logging')
help='Benchmark mode')
PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
QM9DataModule.add_argparse_args(PARSER) help='Benchmark mode')
SE3TransformerPooled.add_argparse_args(PARSER)
QM9DataModule.add_argparse_args(PARSER)
SE3TransformerPooled.add_argparse_args(PARSER)
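
The boolean flags above (`--amp`, `--silent`, `--wandb`, `--benchmark`, `--low_memory`) all combine `type=str2bool` with `nargs='?'` and `const=True`. The sketch below shows how that combination behaves with the standard `argparse` module. The `str2bool` here is a hypothetical re-implementation, since the body of `se3_transformer.runtime.utils.str2bool` is not shown in this diff; the accepted spellings are assumptions.

```python
import argparse


def str2bool(value):
    """Hypothetical stand-in for se3_transformer.runtime.utils.str2bool."""
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False)

flag_absent = parser.parse_args([]).low_memory                         # default is used
flag_bare = parser.parse_args(['--low_memory']).low_memory             # const is used
flag_false = parser.parse_args(['--low_memory', 'false']).low_memory   # parsed by str2bool
flag_true = parser.parse_args(['--low_memory', 'true']).low_memory     # parsed by str2bool
```

With this pattern a flag can be written either bare (`--low_memory`) or with an explicit value (`--low_memory false`), which is convenient when a launch script always passes the value through.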


@ -32,7 +32,7 @@ from tqdm import tqdm
from se3_transformer.runtime import gpu_affinity from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import BaseCallback from se3_transformer.runtime.callbacks import BaseCallback
from se3_transformer.runtime.loggers import DLLogger from se3_transformer.runtime.loggers import DLLogger, WandbLogger, LoggerCollection
from se3_transformer.runtime.utils import to_cuda, get_local_rank from se3_transformer.runtime.utils import to_cuda, get_local_rank
@ -87,7 +87,10 @@ if __name__ == '__main__':
major_cc, minor_cc = torch.cuda.get_device_capability() major_cc, minor_cc = torch.cuda.get_device_capability()
logger = DLLogger(args.log_dir, filename=args.dllogger_name) loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)]
if args.wandb:
loggers.append(WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer'))
logger = LoggerCollection(loggers)
datamodule = QM9DataModule(**vars(args)) datamodule = QM9DataModule(**vars(args))
model = SE3TransformerPooled( model = SE3TransformerPooled(
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}), fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
@ -108,6 +111,7 @@ if __name__ == '__main__':
nproc_per_node = torch.cuda.device_count() nproc_per_node = torch.cuda.device_count()
affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node) affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node)
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
model._set_static_graph()
test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader() test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
evaluate(model, evaluate(model,
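
Note on the logger change: the hunk above builds a list that always contains the DLLogger and appends the W&B logger only when `args.wandb` is set, then wraps the list in a `LoggerCollection`. The pattern can be sketched in plain Python as below. The `Logger` and `LoggerCollection` classes here are simplified, hypothetical stand-ins and not the classes from `se3_transformer.runtime.loggers`.

```python
class Logger:
    """Hypothetical minimal logger that records what it is asked to log."""

    def __init__(self, name):
        self.name = name
        self.records = []

    def log_metrics(self, metrics, step=None):
        self.records.append((step, dict(metrics)))


class LoggerCollection(Logger):
    """Forwards every call to each wrapped logger."""

    def __init__(self, loggers):
        super().__init__('collection')
        self.loggers = list(loggers)

    def log_metrics(self, metrics, step=None):
        for logger in self.loggers:
            logger.log_metrics(metrics, step)


def build_logger(use_wandb):
    # The first logger is always present; the second is opt-in.
    loggers = [Logger('dllogger')]
    if use_wandb:
        loggers.append(Logger('wandb'))
    return LoggerCollection(loggers)


default_logger = build_logger(use_wandb=False)
default_logger.log_metrics({'train loss': 0.5}, step=0)
```

Because callers only ever see the collection, the rest of the training and inference code does not need to know which backends are active.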

View file

@ -1,240 +1,241 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a # Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"), # copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation # to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, # the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the # and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions: # Software is furnished to do so, subject to the following conditions:
# #
# The above copyright notice and this permission notice shall be included in # The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software. # all copies or substantial portions of the Software.
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE. # DEALINGS IN THE SOFTWARE.
# #
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import logging import logging
import pathlib import pathlib
from typing import List from typing import List
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from apex.optimizers import FusedAdam, FusedLAMB from apex.optimizers import FusedAdam, FusedLAMB
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm from tqdm import tqdm
from se3_transformer.data_loading import QM9DataModule from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled from se3_transformer.model import SE3TransformerPooled
from se3_transformer.model.fiber import Fiber from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime import gpu_affinity from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \ from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
PerformanceCallback PerformanceCallback
from se3_transformer.runtime.inference import evaluate from se3_transformer.runtime.inference import evaluate
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \ from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
using_tensor_cores, increase_l2_fetch_granularity using_tensor_cores, increase_l2_fetch_granularity
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]): def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
""" Saves model, optimizer and epoch states to path (only once per node) """ """ Saves model, optimizer and epoch states to path (only once per node) """
if get_local_rank() == 0: if get_local_rank() == 0:
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict() state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
checkpoint = { checkpoint = {
'state_dict': state_dict, 'state_dict': state_dict,
'optimizer_state_dict': optimizer.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch 'epoch': epoch
} }
for callback in callbacks: for callback in callbacks:
callback.on_checkpoint_save(checkpoint) callback.on_checkpoint_save(checkpoint)
torch.save(checkpoint, str(path)) torch.save(checkpoint, str(path))
logging.info(f'Saved checkpoint to {str(path)}') logging.info(f'Saved checkpoint to {str(path)}')
def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]): def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
""" Loads model, optimizer and epoch states from path """ """ Loads model, optimizer and epoch states from path """
checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'}) checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
if isinstance(model, DistributedDataParallel): if isinstance(model, DistributedDataParallel):
model.module.load_state_dict(checkpoint['state_dict']) model.module.load_state_dict(checkpoint['state_dict'])
else: else:
model.load_state_dict(checkpoint['state_dict']) model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
for callback in callbacks: for callback in callbacks:
callback.on_checkpoint_load(checkpoint) callback.on_checkpoint_load(checkpoint)
logging.info(f'Loaded checkpoint from {str(path)}') logging.info(f'Loaded checkpoint from {str(path)}')
return checkpoint['epoch'] return checkpoint['epoch']
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args): def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
losses = [] losses = []
for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch', for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)): desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
*inputs, target = to_cuda(batch) *inputs, target = to_cuda(batch)
for callback in callbacks: for callback in callbacks:
callback.on_batch_start() callback.on_batch_start()
with torch.cuda.amp.autocast(enabled=args.amp): with torch.cuda.amp.autocast(enabled=args.amp):
pred = model(*inputs) pred = model(*inputs)
loss = loss_fn(pred, target) / args.accumulate_grad_batches loss = loss_fn(pred, target) / args.accumulate_grad_batches
grad_scaler.scale(loss).backward() grad_scaler.scale(loss).backward()
# gradient accumulation # gradient accumulation
if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader): if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
if args.gradient_clip: if args.gradient_clip:
grad_scaler.unscale_(optimizer) grad_scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip) torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
grad_scaler.step(optimizer) grad_scaler.step(optimizer)
grad_scaler.update() grad_scaler.update()
model.zero_grad(set_to_none=True) model.zero_grad(set_to_none=True)
losses.append(loss.item()) losses.append(loss.item())
return np.mean(losses) return np.mean(losses)
def train(model: nn.Module, def train(model: nn.Module,
loss_fn: _Loss, loss_fn: _Loss,
train_dataloader: DataLoader, train_dataloader: DataLoader,
val_dataloader: DataLoader, val_dataloader: DataLoader,
callbacks: List[BaseCallback], callbacks: List[BaseCallback],
logger: Logger, logger: Logger,
args): args):
device = torch.cuda.current_device() device = torch.cuda.current_device()
model.to(device=device) model.to(device=device)
local_rank = get_local_rank() local_rank = get_local_rank()
world_size = dist.get_world_size() if dist.is_initialized() else 1 world_size = dist.get_world_size() if dist.is_initialized() else 1
if dist.is_initialized(): if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
model._set_static_graph()
model.train()
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp) model.train()
if args.optimizer == 'adam': grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999), if args.optimizer == 'adam':
weight_decay=args.weight_decay) optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
elif args.optimizer == 'lamb': weight_decay=args.weight_decay)
optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999), elif args.optimizer == 'lamb':
weight_decay=args.weight_decay) optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
else: weight_decay=args.weight_decay)
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum, else:
weight_decay=args.weight_decay) optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
weight_decay=args.weight_decay)
epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0
epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0
for callback in callbacks:
callback.on_fit_start(optimizer, args) for callback in callbacks:
callback.on_fit_start(optimizer, args)
for epoch_idx in range(epoch_start, args.epochs):
if isinstance(train_dataloader.sampler, DistributedSampler): for epoch_idx in range(epoch_start, args.epochs):
train_dataloader.sampler.set_epoch(epoch_idx) if isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch_idx)
loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
if dist.is_initialized(): loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks,
loss = torch.tensor(loss, dtype=torch.float, device=device) args)
torch.distributed.all_reduce(loss) if dist.is_initialized():
loss = (loss / world_size).item() loss = torch.tensor(loss, dtype=torch.float, device=device)
torch.distributed.all_reduce(loss)
logging.info(f'Train loss: {loss}') loss = (loss / world_size).item()
logger.log_metrics({'train loss': loss}, epoch_idx)
logging.info(f'Train loss: {loss}')
for callback in callbacks: logger.log_metrics({'train loss': loss}, epoch_idx)
callback.on_epoch_end()
for callback in callbacks:
if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \ callback.on_epoch_end()
and (epoch_idx + 1) % args.ckpt_interval == 0:
save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks) if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
and (epoch_idx + 1) % args.ckpt_interval == 0:
if not args.benchmark and ((args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0) or epoch_idx + 1 == args.epochs): save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)
evaluate(model, val_dataloader, callbacks, args)
model.train() if not args.benchmark and (
(args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0) or epoch_idx + 1 == args.epochs):
for callback in callbacks: evaluate(model, val_dataloader, callbacks, args)
callback.on_validation_end(epoch_idx) model.train()
if args.save_ckpt_path is not None and not args.benchmark: for callback in callbacks:
save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks) callback.on_validation_end(epoch_idx)
for callback in callbacks: if args.save_ckpt_path is not None and not args.benchmark:
callback.on_fit_end() save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)
for callback in callbacks:
def print_parameters_count(model): callback.on_fit_end()
num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
logging.info(f'Number of trainable parameters: {num_params_trainable}')
def print_parameters_count(model):
num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
if __name__ == '__main__': logging.info(f'Number of trainable parameters: {num_params_trainable}')
is_distributed = init_distributed()
local_rank = get_local_rank()
args = PARSER.parse_args() if __name__ == '__main__':
is_distributed = init_distributed()
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO) local_rank = get_local_rank()
args = PARSER.parse_args()
logging.info('====== SE(3)-Transformer ======')
logging.info('| Training procedure |') logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
logging.info('===============================')
logging.info('====== SE(3)-Transformer ======')
if args.seed is not None: logging.info('| Training procedure |')
logging.info(f'Using seed {args.seed}') logging.info('===============================')
seed_everything(args.seed)
if args.seed is not None:
logger = LoggerCollection([ logging.info(f'Using seed {args.seed}')
DLLogger(save_dir=args.log_dir, filename=args.dllogger_name), seed_everything(args.seed)
WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
]) loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)]
if args.wandb:
datamodule = QM9DataModule(**vars(args)) loggers.append(WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer'))
model = SE3TransformerPooled( logger = LoggerCollection(loggers)
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
fiber_out=Fiber({0: args.num_degrees * args.num_channels}), datamodule = QM9DataModule(**vars(args))
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}), model = SE3TransformerPooled(
output_dim=1, fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
**vars(args) fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
) output_dim=1,
loss_fn = nn.L1Loss() tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively
**vars(args)
if args.benchmark: )
logging.info('Running benchmark mode') loss_fn = nn.L1Loss()
world_size = dist.get_world_size() if dist.is_initialized() else 1
callbacks = [PerformanceCallback(logger, args.batch_size * world_size)] if args.benchmark:
else: logging.info('Running benchmark mode')
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'), world_size = dist.get_world_size() if dist.is_initialized() else 1
QM9LRSchedulerCallback(logger, epochs=args.epochs)] callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
else:
if is_distributed: callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count()) QM9LRSchedulerCallback(logger, epochs=args.epochs)]
print_parameters_count(model) if is_distributed:
logger.log_hyperparams(vars(args)) gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())
increase_l2_fetch_granularity()
train(model, print_parameters_count(model)
loss_fn, logger.log_hyperparams(vars(args))
datamodule.train_dataloader(), increase_l2_fetch_granularity()
datamodule.val_dataloader(), train(model,
callbacks, loss_fn,
logger, datamodule.train_dataloader(),
args) datamodule.val_dataloader(),
callbacks,
logging.info('Training finished successfully') logger,
args)
logging.info('Training finished successfully')
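
Note on scheduling in this file: `train_epoch` steps the optimizer when `(i + 1) % args.accumulate_grad_batches == 0` or on the last batch, and `train` evaluates when `(epoch_idx + 1) % args.eval_interval == 0` or on the last epoch. The sketch below isolates those two conditions as pure functions so they can be checked without a GPU. The function names are illustrative and do not exist in the repository.

```python
def optimizer_step_indices(num_batches, accumulate_grad_batches):
    """Batch indices after which the optimizer would step (same test as train_epoch)."""
    return [i for i in range(num_batches)
            if (i + 1) % accumulate_grad_batches == 0 or (i + 1) == num_batches]


def should_evaluate(epoch_idx, eval_interval, epochs, benchmark=False):
    """Same test as the evaluation branch in train()."""
    return not benchmark and (
        (eval_interval > 0 and (epoch_idx + 1) % eval_interval == 0)
        or epoch_idx + 1 == epochs)


# 10 batches with accumulation over 4: step after batches 3 and 7,
# plus a final step on the last (partial) group at batch 9.
steps = optimizer_step_indices(num_batches=10, accumulate_grad_batches=4)

# 50 epochs with eval_interval=20: evaluate at epochs 19 and 39, and at the last epoch 49.
eval_epochs = [e for e in range(50) if should_evaluate(e, eval_interval=20, epochs=50)]
```

The extra "last batch" and "last epoch" clauses ensure that leftover accumulated gradients are applied and that a final evaluation always runs.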

View file

@ -2,9 +2,9 @@ from setuptools import setup, find_packages
setup( setup(
name='se3-transformer', name='se3-transformer',
packages=find_packages(), packages=find_packages(exclude=['tests']),
include_package_data=True, include_package_data=True,
version='1.0.0', version='1.1.0',
description='PyTorch + DGL implementation of SE(3)-Transformers', description='PyTorch + DGL implementation of SE(3)-Transformers',
author='Alexandre Milesi', author='Alexandre Milesi',
author_email='alexandrem@nvidia.com', author_email='alexandrem@nvidia.com',

View file

@ -1,102 +1,106 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a # Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"), # copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation # to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, # the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the # and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions: # Software is furnished to do so, subject to the following conditions:
# #
# The above copyright notice and this permission notice shall be included in # The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software. # all copies or substantial portions of the Software.
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE. # DEALINGS IN THE SOFTWARE.
# #
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import torch import torch
from se3_transformer.model import SE3Transformer from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber from se3_transformer.model.fiber import Fiber
from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot
if __package__ is None or __package__ == '':
# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) ) from utils import get_random_graph, assign_relative_pos, get_max_diff, rot
TOL = 1e-3 else:
CHANNELS, NODES = 32, 512 from .utils import get_random_graph, assign_relative_pos, get_max_diff, rot
# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) )
def _get_outputs(model, R): TOL = 1e-3
feats0 = torch.randn(NODES, CHANNELS, 1) CHANNELS, NODES = 32, 512
feats1 = torch.randn(NODES, CHANNELS, 3)
coords = torch.randn(NODES, 3) def _get_outputs(model, R):
graph = get_random_graph(NODES) feats0 = torch.randn(NODES, CHANNELS, 1)
if torch.cuda.is_available(): feats1 = torch.randn(NODES, CHANNELS, 3)
feats0 = feats0.cuda()
feats1 = feats1.cuda() coords = torch.randn(NODES, 3)
R = R.cuda() graph = get_random_graph(NODES)
coords = coords.cuda() if torch.cuda.is_available():
graph = graph.to('cuda') feats0 = feats0.cuda()
model.cuda() feats1 = feats1.cuda()
R = R.cuda()
graph1 = assign_relative_pos(graph, coords) coords = coords.cuda()
out1 = model(graph1, {'0': feats0, '1': feats1}, {}) graph = graph.to('cuda')
graph2 = assign_relative_pos(graph, coords @ R) model.cuda()
out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {})
graph1 = assign_relative_pos(graph, coords)
return out1, out2 out1 = model(graph1, {'0': feats0, '1': feats1}, {})
graph2 = assign_relative_pos(graph, coords @ R)
out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {})
def _get_model(**kwargs):
return SE3Transformer( return out1, out2
num_layers=4,
fiber_in=Fiber.create(2, CHANNELS),
fiber_hidden=Fiber.create(3, CHANNELS), def _get_model(**kwargs):
fiber_out=Fiber.create(2, CHANNELS), return SE3Transformer(
fiber_edge=Fiber({}), num_layers=4,
num_heads=8, fiber_in=Fiber.create(2, CHANNELS),
channels_div=2, fiber_hidden=Fiber.create(3, CHANNELS),
**kwargs fiber_out=Fiber.create(2, CHANNELS),
) fiber_edge=Fiber({}),
num_heads=8,
channels_div=2,
def test_equivariance(): **kwargs
model = _get_model() )
R = rot(*torch.rand(3))
if torch.cuda.is_available():
R = R.cuda() def test_equivariance():
out1, out2 = _get_outputs(model, R) model = _get_model()
R = rot(*torch.rand(3))
assert torch.allclose(out2['0'], out1['0'], atol=TOL), \ if torch.cuda.is_available():
f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}' R = R.cuda()
assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \ out1, out2 = _get_outputs(model, R)
f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}'
assert torch.allclose(out2['0'], out1['0'], atol=TOL), \
f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}'
def test_equivariance_pooled(): assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \
model = _get_model(pooling='avg', return_type=1) f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}'
R = rot(*torch.rand(3))
if torch.cuda.is_available():
R = R.cuda() def test_equivariance_pooled():
out1, out2 = _get_outputs(model, R) model = _get_model(pooling='avg', return_type=1)
R = rot(*torch.rand(3))
assert torch.allclose(out2, (out1 @ R), atol=TOL), \ if torch.cuda.is_available():
f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}' R = R.cuda()
out1, out2 = _get_outputs(model, R)
def test_invariance_pooled(): assert torch.allclose(out2, (out1 @ R), atol=TOL), \
model = _get_model(pooling='avg', return_type=0) f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}'
R = rot(*torch.rand(3))
if torch.cuda.is_available():
R = R.cuda() def test_invariance_pooled():
out1, out2 = _get_outputs(model, R) model = _get_model(pooling='avg', return_type=0)
R = rot(*torch.rand(3))
assert torch.allclose(out2, out1, atol=TOL), \ if torch.cuda.is_available():
f'type-0 features should be invariant {get_max_diff(out1, out2)}' R = R.cuda()
out1, out2 = _get_outputs(model, R)
assert torch.allclose(out2, out1, atol=TOL), \
f'type-0 features should be invariant {get_max_diff(out1, out2)}'
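
Note on what these tests assert: type-0 outputs must be unchanged when the inputs are rotated, and type-1 outputs must rotate with the inputs, i.e. `f(x @ R)` should match `f(x) @ R` up to `TOL`. The toy sketch below illustrates the same two checks in plain Python with functions that are known to satisfy them (a centroid is equivariant, a mean norm is invariant). It is only an illustration of the property, not a substitute for running the model.

```python
import math


def rotate_rows(points, R):
    """Apply `p @ R` to every row vector p."""
    return [[sum(p[i] * R[i][j] for i in range(3)) for j in range(3)] for p in points]


def centroid(points):
    """Type-1-like quantity: rotates together with the input."""
    n = len(points)
    return [sum(p[j] for p in points) / n for j in range(3)]


def mean_norm(points):
    """Type-0-like quantity: unchanged by rotation."""
    return sum(math.sqrt(sum(c * c for c in p)) for p in points) / len(points)


def max_diff(a, b):
    return max(abs(x - y) for x, y in zip(a, b))


# A fixed 90-degree rotation about the z axis.
R = [[0.0, -1.0, 0.0],
     [1.0, 0.0, 0.0],
     [0.0, 0.0, 1.0]]
points = [[1.0, 2.0, 3.0], [-0.5, 0.25, 4.0], [2.0, -1.0, 0.5]]
rotated = rotate_rows(points, R)

# Equivariance: f(x @ R) versus f(x) @ R.
equivariance_error = max_diff(centroid(rotated), rotate_rows([centroid(points)], R)[0])
# Invariance: f(x @ R) versus f(x).
invariance_error = abs(mean_norm(rotated) - mean_norm(points))
```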

View file

@ -1,60 +1,60 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a # Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"), # copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation # to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, # the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the # and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions: # Software is furnished to do so, subject to the following conditions:
# #
# The above copyright notice and this permission notice shall be included in # The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software. # all copies or substantial portions of the Software.
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE. # DEALINGS IN THE SOFTWARE.
# #
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import dgl import dgl
import torch import torch
def get_random_graph(N, num_edges_factor=18): def get_random_graph(N, num_edges_factor=18):
graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor)) graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor))
return graph return graph
def assign_relative_pos(graph, coords): def assign_relative_pos(graph, coords):
src, dst = graph.edges() src, dst = graph.edges()
graph.edata['rel_pos'] = coords[src] - coords[dst] graph.edata['rel_pos'] = coords[src] - coords[dst]
return graph return graph
def get_max_diff(a, b): def get_max_diff(a, b):
return (a - b).abs().max().item() return (a - b).abs().max().item()
def rot_z(gamma): def rot_z(gamma):
return torch.tensor([ return torch.tensor([
[torch.cos(gamma), -torch.sin(gamma), 0], [torch.cos(gamma), -torch.sin(gamma), 0],
[torch.sin(gamma), torch.cos(gamma), 0], [torch.sin(gamma), torch.cos(gamma), 0],
[0, 0, 1] [0, 0, 1]
], dtype=gamma.dtype) ], dtype=gamma.dtype)
def rot_y(beta): def rot_y(beta):
return torch.tensor([ return torch.tensor([
[torch.cos(beta), 0, torch.sin(beta)], [torch.cos(beta), 0, torch.sin(beta)],
[0, 1, 0], [0, 1, 0],
[-torch.sin(beta), 0, torch.cos(beta)] [-torch.sin(beta), 0, torch.cos(beta)]
], dtype=beta.dtype) ], dtype=beta.dtype)
def rot(alpha, beta, gamma): def rot(alpha, beta, gamma):
return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma) return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
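
Note on `rot`: it composes `rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)`, a Z-Y-Z Euler-angle construction, so the result should be a proper rotation (orthogonal, determinant 1). The sketch below mirrors the three helpers with plain Python lists instead of torch tensors, which is an assumption made for portability, and checks both properties for one set of angles.

```python
import math


def rot_z(gamma):
    c, s = math.cos(gamma), math.sin(gamma)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def rot_y(beta):
    c, s = math.cos(beta), math.sin(beta)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]


def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def rot(alpha, beta, gamma):
    # Same composition order as the helper in this file.
    return matmul(matmul(rot_z(alpha), rot_y(beta)), rot_z(gamma))


def transpose(A):
    return [[A[j][i] for j in range(3)] for i in range(3)]


def det3(A):
    return (A[0][0] * (A[1][1] * A[2][2] - A[1][2] * A[2][1])
            - A[0][1] * (A[1][0] * A[2][2] - A[1][2] * A[2][0])
            + A[0][2] * (A[1][0] * A[2][1] - A[1][1] * A[2][0]))


R = rot(0.3, 1.1, -0.7)
RtR = matmul(transpose(R), R)
# Largest deviation of R^T R from the identity matrix.
orthogonality_error = max(abs(RtR[i][j] - (1.0 if i == j else 0.0))
                          for i in range(3) for j in range(3))
determinant = det3(R)
```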

View file

@ -83,7 +83,7 @@ These examples, along with our NVIDIA deep learning software stack, are provided
## Graph Neural Networks ## Graph Neural Networks
| Models | Framework | A100 | AMP | Multi-GPU | Multi-Node | TRT | ONNX | Triton | DLC | NB | | Models | Framework | A100 | AMP | Multi-GPU | Multi-Node | TRT | ONNX | Triton | DLC | NB |
| ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |------------- |------------- |------------- |------------- |------------- | | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |------------- |------------- |------------- |------------- |------------- |
| [SE(3)-Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/DrugDiscovery/SE3Transformer) | PyTorch | Yes | Yes | Yes | - | - | - | - | - | - | | [SE(3)-Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/DGLPyTorch/DrugDiscovery/SE3Transformer) | PyTorch | Yes | Yes | Yes | - | - | - | - | - | - |
## NVIDIA support ## NVIDIA support