[SE3Transformer/DGLPyT] Better low memory mode
This commit is contained in:
parent
dcd3bbac09
commit
22d6621dcd
|
@ -42,7 +42,6 @@ RUN make -j8
|
||||||
|
|
||||||
FROM ${FROM_IMAGE_NAME}
|
FROM ${FROM_IMAGE_NAME}
|
||||||
|
|
||||||
RUN rm -rf /workspace/*
|
|
||||||
WORKDIR /workspace/se3-transformer
|
WORKDIR /workspace/se3-transformer
|
||||||
|
|
||||||
# copy built DGL and install it
|
# copy built DGL and install it
|
||||||
|
@ -55,3 +54,5 @@ ADD . .
|
||||||
|
|
||||||
ENV DGLBACKEND=pytorch
|
ENV DGLBACKEND=pytorch
|
||||||
ENV OMP_NUM_THREADS=1
|
ENV OMP_NUM_THREADS=1
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -126,7 +126,13 @@ The following performance optimizations were implemented in this model:
|
||||||
- The layout (order of dimensions) of the bases tensors is optimized to avoid copies to contiguous memory in the downstream TFN layers
|
- The layout (order of dimensions) of the bases tensors is optimized to avoid copies to contiguous memory in the downstream TFN layers
|
||||||
- When Tensor Cores are available, and the output feature dimension of computed bases is odd, then it is padded with zeros to make more effective use of Tensor Cores (AMP and TF32 precisions)
|
- When Tensor Cores are available, and the output feature dimension of computed bases is odd, then it is padded with zeros to make more effective use of Tensor Cores (AMP and TF32 precisions)
|
||||||
- Multiple levels of fusion for TFN convolutions (and radial profiles) are provided and automatically used when conditions are met
|
- Multiple levels of fusion for TFN convolutions (and radial profiles) are provided and automatically used when conditions are met
|
||||||
- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`)
|
- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`). Overview of memory savings over the official implementation (batch size 100), depending on the precision and the low memory mode:
|
||||||
|
|
||||||
|
| | FP32 | AMP
|
||||||
|
|---|-----------------------|--------------------------
|
||||||
|
|`--low_memory false` (default) | 4.7x | 7.1x
|
||||||
|
|`--low_memory true` | 29.4x | 43.6x
|
||||||
|
|
||||||
|
|
||||||
**Self-attention optimizations**
|
**Self-attention optimizations**
|
||||||
|
|
||||||
|
@ -358,7 +364,7 @@ The complete list of the available parameters for the `training.py` script conta
|
||||||
- `--pooling`: Type of graph pooling (default: `max`)
|
- `--pooling`: Type of graph pooling (default: `max`)
|
||||||
- `--norm`: Apply a normalization layer after each attention block (default: `false`)
|
- `--norm`: Apply a normalization layer after each attention block (default: `false`)
|
||||||
- `--use_layer_norm`: Apply layer normalization between MLP layers (default: `false`)
|
- `--use_layer_norm`: Apply layer normalization between MLP layers (default: `false`)
|
||||||
- `--low_memory`: If true, will use fused ops that are slower but use less memory (expect 25 percent less memory). Only has an effect if AMP is enabled on NVIDIA Volta GPUs or if running on Ampere GPUs (default: `false`)
|
- `--low_memory`: If true, will use ops that are slower but use less memory (default: `false`)
|
||||||
- `--num_degrees`: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: `4`)
|
- `--num_degrees`: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: `4`)
|
||||||
- `--num_channels`: Number of channels for the hidden features (default: `32`)
|
- `--num_channels`: Number of channels for the hidden features (default: `32`)
|
||||||
|
|
||||||
|
@ -407,7 +413,8 @@ The training script is `se3_transformer/runtime/training.py`, to be run as a mod
|
||||||
|
|
||||||
By default, the resulting logs are stored in `/results/`. This can be changed with `--log_dir`.
|
By default, the resulting logs are stored in `/results/`. This can be changed with `--log_dir`.
|
||||||
|
|
||||||
You can connect your existing Weights & Biases account by setting the `WANDB_API_KEY` environment variable.
|
You can connect your existing Weights & Biases account by setting the WANDB_API_KEY environment variable, and enabling the `--wandb` flag.
|
||||||
|
If no API key is set, `--wandb` will log the run anonymously to Weights & Biases.
|
||||||
|
|
||||||
**Checkpoints**
|
**Checkpoints**
|
||||||
|
|
||||||
|
@ -573,6 +580,11 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
|
||||||
|
|
||||||
### Changelog
|
### Changelog
|
||||||
|
|
||||||
|
November 2021:
|
||||||
|
- Improved low memory mode to give further 6x memory savings
|
||||||
|
- Disabled W&B logging by default
|
||||||
|
- Fixed persistent workers when using one data loading process
|
||||||
|
|
||||||
October 2021:
|
October 2021:
|
||||||
- Updated README performance tables
|
- Updated README performance tables
|
||||||
- Fixed shape mismatch when using partially fused TFNs per output degree
|
- Fixed shape mismatch when using partially fused TFNs per output degree
|
||||||
|
|
|
@ -46,7 +46,8 @@ class DataModule(ABC):
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
dist.barrier(device_ids=[get_local_rank()])
|
dist.barrier(device_ids=[get_local_rank()])
|
||||||
|
|
||||||
self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs}
|
self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': dataloader_kwargs.get('num_workers', 0) > 0,
|
||||||
|
**dataloader_kwargs}
|
||||||
self.ds_train, self.ds_val, self.ds_test = None, None, None
|
self.ds_train, self.ds_val, self.ds_test = None, None, None
|
||||||
|
|
||||||
def prepare_data(self):
|
def prepare_data(self):
|
||||||
|
|
|
@ -1,2 +1,2 @@
|
||||||
from .transformer import SE3Transformer, SE3TransformerPooled
|
from .transformer import SE3Transformer, SE3TransformerPooled
|
||||||
from .fiber import Fiber
|
from .fiber import Fiber
|
||||||
|
|
|
@ -1,144 +1,144 @@
|
||||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
#
|
#
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
# copy of this software and associated documentation files (the "Software"),
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
# to deal in the Software without restriction, including without limitation
|
# to deal in the Software without restriction, including without limitation
|
||||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
# and/or sell copies of the Software, and to permit persons to whom the
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
# Software is furnished to do so, subject to the following conditions:
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
#
|
#
|
||||||
# The above copyright notice and this permission notice shall be included in
|
# The above copyright notice and this permission notice shall be included in
|
||||||
# all copies or substantial portions of the Software.
|
# all copies or substantial portions of the Software.
|
||||||
#
|
#
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
# DEALINGS IN THE SOFTWARE.
|
# DEALINGS IN THE SOFTWARE.
|
||||||
#
|
#
|
||||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from se3_transformer.runtime.utils import degree_to_dim
|
from se3_transformer.runtime.utils import degree_to_dim
|
||||||
|
|
||||||
FiberEl = namedtuple('FiberEl', ['degree', 'channels'])
|
FiberEl = namedtuple('FiberEl', ['degree', 'channels'])
|
||||||
|
|
||||||
|
|
||||||
class Fiber(dict):
|
class Fiber(dict):
|
||||||
"""
|
"""
|
||||||
Describes the structure of some set of features.
|
Describes the structure of some set of features.
|
||||||
Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1.
|
Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1.
|
||||||
Type-0 features: invariant scalars
|
Type-0 features: invariant scalars
|
||||||
Type-1 features: equivariant 3D vectors
|
Type-1 features: equivariant 3D vectors
|
||||||
Type-2 features: equivariant symmetric traceless matrices
|
Type-2 features: equivariant symmetric traceless matrices
|
||||||
...
|
...
|
||||||
|
|
||||||
As inputs to a SE3 layer, there can be many features of the same types, and many features of different types.
|
As inputs to a SE3 layer, there can be many features of the same types, and many features of different types.
|
||||||
The 'multiplicity' or 'number of channels' is the number of features of a given type.
|
The 'multiplicity' or 'number of channels' is the number of features of a given type.
|
||||||
This class puts together all the degrees and their multiplicities in order to describe
|
This class puts together all the degrees and their multiplicities in order to describe
|
||||||
the inputs, outputs or hidden features of SE3 layers.
|
the inputs, outputs or hidden features of SE3 layers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, structure):
|
def __init__(self, structure):
|
||||||
if isinstance(structure, dict):
|
if isinstance(structure, dict):
|
||||||
structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])]
|
structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])]
|
||||||
elif not isinstance(structure[0], FiberEl):
|
elif not isinstance(structure[0], FiberEl):
|
||||||
structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1])))
|
structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1])))
|
||||||
self.structure = structure
|
self.structure = structure
|
||||||
super().__init__({d: m for d, m in self.structure})
|
super().__init__({d: m for d, m in self.structure})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def degrees(self):
|
def degrees(self):
|
||||||
return sorted([t.degree for t in self.structure])
|
return sorted([t.degree for t in self.structure])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def channels(self):
|
def channels(self):
|
||||||
return [self[d] for d in self.degrees]
|
return [self[d] for d in self.degrees]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_features(self):
|
def num_features(self):
|
||||||
""" Size of the resulting tensor if all features were concatenated together """
|
""" Size of the resulting tensor if all features were concatenated together """
|
||||||
return sum(t.channels * degree_to_dim(t.degree) for t in self.structure)
|
return sum(t.channels * degree_to_dim(t.degree) for t in self.structure)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(num_degrees: int, num_channels: int):
|
def create(num_degrees: int, num_channels: int):
|
||||||
""" Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """
|
""" Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """
|
||||||
return Fiber([(degree, num_channels) for degree in range(num_degrees)])
|
return Fiber([(degree, num_channels) for degree in range(num_degrees)])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_features(feats: Dict[str, Tensor]):
|
def from_features(feats: Dict[str, Tensor]):
|
||||||
""" Infer the Fiber structure from a feature dict """
|
""" Infer the Fiber structure from a feature dict """
|
||||||
structure = {}
|
structure = {}
|
||||||
for k, v in feats.items():
|
for k, v in feats.items():
|
||||||
degree = int(k)
|
degree = int(k)
|
||||||
assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
|
assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
|
||||||
assert v.shape[-1] == degree_to_dim(degree)
|
assert v.shape[-1] == degree_to_dim(degree)
|
||||||
structure[degree] = v.shape[-2]
|
structure[degree] = v.shape[-2]
|
||||||
return Fiber(structure)
|
return Fiber(structure)
|
||||||
|
|
||||||
def __getitem__(self, degree: int):
|
def __getitem__(self, degree: int):
|
||||||
""" fiber[degree] returns the multiplicity for this degree """
|
""" fiber[degree] returns the multiplicity for this degree """
|
||||||
return dict(self.structure).get(degree, 0)
|
return dict(self.structure).get(degree, 0)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
""" Iterate over namedtuples (degree, channels) """
|
""" Iterate over namedtuples (degree, channels) """
|
||||||
return iter(self.structure)
|
return iter(self.structure)
|
||||||
|
|
||||||
def __mul__(self, other):
|
def __mul__(self, other):
|
||||||
"""
|
"""
|
||||||
If other in an int, multiplies all the multiplicities by other.
|
If other in an int, multiplies all the multiplicities by other.
|
||||||
If other is a fiber, returns the cartesian product.
|
If other is a fiber, returns the cartesian product.
|
||||||
"""
|
"""
|
||||||
if isinstance(other, Fiber):
|
if isinstance(other, Fiber):
|
||||||
return product(self.structure, other.structure)
|
return product(self.structure, other.structure)
|
||||||
elif isinstance(other, int):
|
elif isinstance(other, int):
|
||||||
return Fiber({t.degree: t.channels * other for t in self.structure})
|
return Fiber({t.degree: t.channels * other for t in self.structure})
|
||||||
|
|
||||||
def __add__(self, other):
|
def __add__(self, other):
|
||||||
"""
|
"""
|
||||||
If other in an int, add other to all the multiplicities.
|
If other in an int, add other to all the multiplicities.
|
||||||
If other is a fiber, add the multiplicities of the fibers together.
|
If other is a fiber, add the multiplicities of the fibers together.
|
||||||
"""
|
"""
|
||||||
if isinstance(other, Fiber):
|
if isinstance(other, Fiber):
|
||||||
return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure})
|
return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure})
|
||||||
elif isinstance(other, int):
|
elif isinstance(other, int):
|
||||||
return Fiber({t.degree: t.channels + other for t in self.structure})
|
return Fiber({t.degree: t.channels + other for t in self.structure})
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return str(self.structure)
|
return str(self.structure)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def combine_max(f1, f2):
|
def combine_max(f1, f2):
|
||||||
""" Combine two fiber by taking the maximum multiplicity for each degree in both fibers """
|
""" Combine two fiber by taking the maximum multiplicity for each degree in both fibers """
|
||||||
new_dict = dict(f1.structure)
|
new_dict = dict(f1.structure)
|
||||||
for k, m in f2.structure:
|
for k, m in f2.structure:
|
||||||
new_dict[k] = max(new_dict.get(k, 0), m)
|
new_dict[k] = max(new_dict.get(k, 0), m)
|
||||||
|
|
||||||
return Fiber(list(new_dict.items()))
|
return Fiber(list(new_dict.items()))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def combine_selectively(f1, f2):
|
def combine_selectively(f1, f2):
|
||||||
""" Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """
|
""" Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """
|
||||||
# only use orders which occur in fiber f1
|
# only use orders which occur in fiber f1
|
||||||
new_dict = dict(f1.structure)
|
new_dict = dict(f1.structure)
|
||||||
for k in f1.degrees:
|
for k in f1.degrees:
|
||||||
if k in f2.degrees:
|
if k in f2.degrees:
|
||||||
new_dict[k] += f2[k]
|
new_dict[k] += f2[k]
|
||||||
return Fiber(list(new_dict.items()))
|
return Fiber(list(new_dict.items()))
|
||||||
|
|
||||||
def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int):
|
def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int):
|
||||||
# dict(N, num_channels, 2d+1) -> (N, num_heads, -1)
|
# dict(N, num_channels, 2d+1) -> (N, num_heads, -1)
|
||||||
fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in
|
fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in
|
||||||
self.degrees]
|
self.degrees]
|
||||||
fibers = torch.cat(fibers, -1)
|
fibers = torch.cat(fibers, -1)
|
||||||
return fibers
|
return fibers
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from .linear import LinearSE3
|
from .linear import LinearSE3
|
||||||
from .norm import NormSE3
|
from .norm import NormSE3
|
||||||
from .pooling import GPooling
|
from .pooling import GPooling
|
||||||
from .convolution import ConvSE3
|
from .convolution import ConvSE3
|
||||||
from .attention import AttentionBlockSE3
|
from .attention import AttentionBlockSE3
|
|
@ -1,180 +1,181 @@
|
||||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
#
|
#
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
# copy of this software and associated documentation files (the "Software"),
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
# to deal in the Software without restriction, including without limitation
|
# to deal in the Software without restriction, including without limitation
|
||||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
# and/or sell copies of the Software, and to permit persons to whom the
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
# Software is furnished to do so, subject to the following conditions:
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
#
|
#
|
||||||
# The above copyright notice and this permission notice shall be included in
|
# The above copyright notice and this permission notice shall be included in
|
||||||
# all copies or substantial portions of the Software.
|
# all copies or substantial portions of the Software.
|
||||||
#
|
#
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
# DEALINGS IN THE SOFTWARE.
|
# DEALINGS IN THE SOFTWARE.
|
||||||
#
|
#
|
||||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
import dgl
|
import dgl
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from dgl import DGLGraph
|
from dgl import DGLGraph
|
||||||
from dgl.ops import edge_softmax
|
from dgl.ops import edge_softmax
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
from se3_transformer.model.fiber import Fiber
|
from se3_transformer.model.fiber import Fiber
|
||||||
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
||||||
from se3_transformer.model.layers.linear import LinearSE3
|
from se3_transformer.model.layers.linear import LinearSE3
|
||||||
from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
|
from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
|
||||||
from torch.cuda.nvtx import range as nvtx_range
|
from torch.cuda.nvtx import range as nvtx_range
|
||||||
|
|
||||||
|
|
||||||
class AttentionSE3(nn.Module):
|
class AttentionSE3(nn.Module):
|
||||||
""" Multi-headed sparse graph self-attention (SE(3)-equivariant) """
|
""" Multi-headed sparse graph self-attention (SE(3)-equivariant) """
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
key_fiber: Fiber,
|
key_fiber: Fiber,
|
||||||
value_fiber: Fiber
|
value_fiber: Fiber
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param num_heads: Number of attention heads
|
:param num_heads: Number of attention heads
|
||||||
:param key_fiber: Fiber for the keys (and also for the queries)
|
:param key_fiber: Fiber for the keys (and also for the queries)
|
||||||
:param value_fiber: Fiber for the values
|
:param value_fiber: Fiber for the values
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.key_fiber = key_fiber
|
self.key_fiber = key_fiber
|
||||||
self.value_fiber = value_fiber
|
self.value_fiber = value_fiber
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
|
value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
|
||||||
key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
|
key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
|
||||||
query: Dict[str, Tensor], # node features
|
query: Dict[str, Tensor], # node features
|
||||||
graph: DGLGraph
|
graph: DGLGraph
|
||||||
):
|
):
|
||||||
with nvtx_range('AttentionSE3'):
|
with nvtx_range('AttentionSE3'):
|
||||||
with nvtx_range('reshape keys and queries'):
|
with nvtx_range('reshape keys and queries'):
|
||||||
if isinstance(key, Tensor):
|
if isinstance(key, Tensor):
|
||||||
# case where features of all types are fused
|
# case where features of all types are fused
|
||||||
key = key.reshape(key.shape[0], self.num_heads, -1)
|
key = key.reshape(key.shape[0], self.num_heads, -1)
|
||||||
# need to reshape queries that way to keep the same layout as keys
|
# need to reshape queries that way to keep the same layout as keys
|
||||||
out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
|
out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
|
||||||
query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
|
query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
|
||||||
else:
|
else:
|
||||||
# features are not fused, need to fuse and reshape them
|
# features are not fused, need to fuse and reshape them
|
||||||
key = self.key_fiber.to_attention_heads(key, self.num_heads)
|
key = self.key_fiber.to_attention_heads(key, self.num_heads)
|
||||||
query = self.key_fiber.to_attention_heads(query, self.num_heads)
|
query = self.key_fiber.to_attention_heads(query, self.num_heads)
|
||||||
|
|
||||||
with nvtx_range('attention dot product + softmax'):
|
with nvtx_range('attention dot product + softmax'):
|
||||||
# Compute attention weights (softmax of inner product between key and query)
|
# Compute attention weights (softmax of inner product between key and query)
|
||||||
edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
|
edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
|
||||||
edge_weights = edge_weights / np.sqrt(self.key_fiber.num_features)
|
edge_weights = edge_weights / np.sqrt(self.key_fiber.num_features)
|
||||||
edge_weights = edge_softmax(graph, edge_weights)
|
edge_weights = edge_softmax(graph, edge_weights)
|
||||||
edge_weights = edge_weights[..., None, None]
|
edge_weights = edge_weights[..., None, None]
|
||||||
|
|
||||||
with nvtx_range('weighted sum'):
|
with nvtx_range('weighted sum'):
|
||||||
if isinstance(value, Tensor):
|
if isinstance(value, Tensor):
|
||||||
# features of all types are fused
|
# features of all types are fused
|
||||||
v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
|
v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
|
||||||
weights = edge_weights * v
|
weights = edge_weights * v
|
||||||
feat_out = dgl.ops.copy_e_sum(graph, weights)
|
feat_out = dgl.ops.copy_e_sum(graph, weights)
|
||||||
feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads
|
feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads
|
||||||
out = unfuse_features(feat_out, self.value_fiber.degrees)
|
out = unfuse_features(feat_out, self.value_fiber.degrees)
|
||||||
else:
|
else:
|
||||||
out = {}
|
out = {}
|
||||||
for degree, channels in self.value_fiber:
|
for degree, channels in self.value_fiber:
|
||||||
v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
|
v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
|
||||||
degree_to_dim(degree))
|
degree_to_dim(degree))
|
||||||
weights = edge_weights * v
|
weights = edge_weights * v
|
||||||
res = dgl.ops.copy_e_sum(graph, weights)
|
res = dgl.ops.copy_e_sum(graph, weights)
|
||||||
out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads
|
out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class AttentionBlockSE3(nn.Module):
|
class AttentionBlockSE3(nn.Module):
|
||||||
""" Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """
|
""" Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
fiber_in: Fiber,
|
fiber_in: Fiber,
|
||||||
fiber_out: Fiber,
|
fiber_out: Fiber,
|
||||||
fiber_edge: Optional[Fiber] = None,
|
fiber_edge: Optional[Fiber] = None,
|
||||||
num_heads: int = 4,
|
num_heads: int = 4,
|
||||||
channels_div: int = 2,
|
channels_div: int = 2,
|
||||||
use_layer_norm: bool = False,
|
use_layer_norm: bool = False,
|
||||||
max_degree: bool = 4,
|
max_degree: bool = 4,
|
||||||
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
|
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
|
||||||
**kwargs
|
low_memory: bool = False,
|
||||||
):
|
**kwargs
|
||||||
"""
|
):
|
||||||
:param fiber_in: Fiber describing the input features
|
"""
|
||||||
:param fiber_out: Fiber describing the output features
|
:param fiber_in: Fiber describing the input features
|
||||||
:param fiber_edge: Fiber describing the edge features (node distances excluded)
|
:param fiber_out: Fiber describing the output features
|
||||||
:param num_heads: Number of attention heads
|
:param fiber_edge: Fiber describing the edge features (node distances excluded)
|
||||||
:param channels_div: Divide the channels by this integer for computing values
|
:param num_heads: Number of attention heads
|
||||||
:param use_layer_norm: Apply layer normalization between MLP layers
|
:param channels_div: Divide the channels by this integer for computing values
|
||||||
:param max_degree: Maximum degree used in the bases computation
|
:param use_layer_norm: Apply layer normalization between MLP layers
|
||||||
:param fuse_level: Maximum fuse level to use in TFN convolutions
|
:param max_degree: Maximum degree used in the bases computation
|
||||||
"""
|
:param fuse_level: Maximum fuse level to use in TFN convolutions
|
||||||
super().__init__()
|
"""
|
||||||
if fiber_edge is None:
|
super().__init__()
|
||||||
fiber_edge = Fiber({})
|
if fiber_edge is None:
|
||||||
self.fiber_in = fiber_in
|
fiber_edge = Fiber({})
|
||||||
# value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
|
self.fiber_in = fiber_in
|
||||||
value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out])
|
# value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
|
||||||
# key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber
|
value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out])
|
||||||
# (queries are merely projected, hence degrees have to match input)
|
# key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber
|
||||||
key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])
|
# (queries are merely projected, hence degrees have to match input)
|
||||||
|
key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])
|
||||||
self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
|
|
||||||
use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
|
self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
|
||||||
allow_fused_output=True)
|
use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
|
||||||
self.to_query = LinearSE3(fiber_in, key_query_fiber)
|
allow_fused_output=True, low_memory=low_memory)
|
||||||
self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
|
self.to_query = LinearSE3(fiber_in, key_query_fiber)
|
||||||
self.project = LinearSE3(value_fiber + fiber_in, fiber_out)
|
self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
|
||||||
|
self.project = LinearSE3(value_fiber + fiber_in, fiber_out)
|
||||||
def forward(
|
|
||||||
self,
|
def forward(
|
||||||
node_features: Dict[str, Tensor],
|
self,
|
||||||
edge_features: Dict[str, Tensor],
|
node_features: Dict[str, Tensor],
|
||||||
graph: DGLGraph,
|
edge_features: Dict[str, Tensor],
|
||||||
basis: Dict[str, Tensor]
|
graph: DGLGraph,
|
||||||
):
|
basis: Dict[str, Tensor]
|
||||||
with nvtx_range('AttentionBlockSE3'):
|
):
|
||||||
with nvtx_range('keys / values'):
|
with nvtx_range('AttentionBlockSE3'):
|
||||||
fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
|
with nvtx_range('keys / values'):
|
||||||
key, value = self._get_key_value_from_fused(fused_key_value)
|
fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
|
||||||
|
key, value = self._get_key_value_from_fused(fused_key_value)
|
||||||
with nvtx_range('queries'):
|
|
||||||
query = self.to_query(node_features)
|
with nvtx_range('queries'):
|
||||||
|
query = self.to_query(node_features)
|
||||||
z = self.attention(value, key, query, graph)
|
|
||||||
z_concat = aggregate_residual(node_features, z, 'cat')
|
z = self.attention(value, key, query, graph)
|
||||||
return self.project(z_concat)
|
z_concat = aggregate_residual(node_features, z, 'cat')
|
||||||
|
return self.project(z_concat)
|
||||||
def _get_key_value_from_fused(self, fused_key_value):
|
|
||||||
# Extract keys and queries features from fused features
|
def _get_key_value_from_fused(self, fused_key_value):
|
||||||
if isinstance(fused_key_value, Tensor):
|
# Extract keys and queries features from fused features
|
||||||
# Previous layer was a fully fused convolution
|
if isinstance(fused_key_value, Tensor):
|
||||||
value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
|
# Previous layer was a fully fused convolution
|
||||||
else:
|
value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
|
||||||
key, value = {}, {}
|
else:
|
||||||
for degree, feat in fused_key_value.items():
|
key, value = {}, {}
|
||||||
if int(degree) in self.fiber_in.degrees:
|
for degree, feat in fused_key_value.items():
|
||||||
value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
|
if int(degree) in self.fiber_in.degrees:
|
||||||
else:
|
value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
|
||||||
value[degree] = feat
|
else:
|
||||||
|
value[degree] = feat
|
||||||
return key, value
|
|
||||||
|
return key, value
|
||||||
|
|
|
@ -1,345 +1,354 @@
|
||||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
#
|
#
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
# copy of this software and associated documentation files (the "Software"),
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
# to deal in the Software without restriction, including without limitation
|
# to deal in the Software without restriction, including without limitation
|
||||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
# and/or sell copies of the Software, and to permit persons to whom the
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
# Software is furnished to do so, subject to the following conditions:
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
#
|
#
|
||||||
# The above copyright notice and this permission notice shall be included in
|
# The above copyright notice and this permission notice shall be included in
|
||||||
# all copies or substantial portions of the Software.
|
# all copies or substantial portions of the Software.
|
||||||
#
|
#
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
# DEALINGS IN THE SOFTWARE.
|
# DEALINGS IN THE SOFTWARE.
|
||||||
#
|
#
|
||||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import dgl
|
import dgl
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from dgl import DGLGraph
|
import torch.utils.checkpoint
|
||||||
from torch import Tensor
|
from dgl import DGLGraph
|
||||||
from torch.cuda.nvtx import range as nvtx_range
|
from torch import Tensor
|
||||||
|
from torch.cuda.nvtx import range as nvtx_range
|
||||||
from se3_transformer.model.fiber import Fiber
|
|
||||||
from se3_transformer.runtime.utils import degree_to_dim, unfuse_features
|
from se3_transformer.model.fiber import Fiber
|
||||||
|
from se3_transformer.runtime.utils import degree_to_dim, unfuse_features
|
||||||
|
|
||||||
class ConvSE3FuseLevel(Enum):
|
|
||||||
"""
|
class ConvSE3FuseLevel(Enum):
|
||||||
Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met.
|
"""
|
||||||
If a desired level L is picked and the level L cannot be applied to a level, other fused ops < L are considered.
|
Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met.
|
||||||
A higher level means faster training, but also more memory usage.
|
If a desired level L is picked and the level L cannot be applied to a level, other fused ops < L are considered.
|
||||||
If you are tight on memory and want to feed large inputs to the network, choose a low value.
|
A higher level means faster training, but also more memory usage.
|
||||||
If you want to train fast, choose a high value.
|
If you are tight on memory and want to feed large inputs to the network, choose a low value.
|
||||||
Recommended value is FULL with AMP.
|
If you want to train fast, choose a high value.
|
||||||
|
Recommended value is FULL with AMP.
|
||||||
Fully fused TFN convolutions requirements:
|
|
||||||
- all input channels are the same
|
Fully fused TFN convolutions requirements:
|
||||||
- all output channels are the same
|
- all input channels are the same
|
||||||
- input degrees span the range [0, ..., max_degree]
|
- all output channels are the same
|
||||||
- output degrees span the range [0, ..., max_degree]
|
- input degrees span the range [0, ..., max_degree]
|
||||||
|
- output degrees span the range [0, ..., max_degree]
|
||||||
Partially fused TFN convolutions requirements:
|
|
||||||
* For fusing by output degree:
|
Partially fused TFN convolutions requirements:
|
||||||
- all input channels are the same
|
* For fusing by output degree:
|
||||||
- input degrees span the range [0, ..., max_degree]
|
- all input channels are the same
|
||||||
* For fusing by input degree:
|
- input degrees span the range [0, ..., max_degree]
|
||||||
- all output channels are the same
|
* For fusing by input degree:
|
||||||
- output degrees span the range [0, ..., max_degree]
|
- all output channels are the same
|
||||||
|
- output degrees span the range [0, ..., max_degree]
|
||||||
Original TFN pairwise convolutions: no requirements
|
|
||||||
"""
|
Original TFN pairwise convolutions: no requirements
|
||||||
|
"""
|
||||||
FULL = 2
|
|
||||||
PARTIAL = 1
|
FULL = 2
|
||||||
NONE = 0
|
PARTIAL = 1
|
||||||
|
NONE = 0
|
||||||
|
|
||||||
class RadialProfile(nn.Module):
|
|
||||||
"""
|
class RadialProfile(nn.Module):
|
||||||
Radial profile function.
|
"""
|
||||||
Outputs weights used to weigh basis matrices in order to get convolution kernels.
|
Radial profile function.
|
||||||
In TFN notation: $R^{l,k}$
|
Outputs weights used to weigh basis matrices in order to get convolution kernels.
|
||||||
In SE(3)-Transformer notation: $\phi^{l,k}$
|
In TFN notation: $R^{l,k}$
|
||||||
|
In SE(3)-Transformer notation: $\phi^{l,k}$
|
||||||
Note:
|
|
||||||
In the original papers, this function only depends on relative node distances ||x||.
|
Note:
|
||||||
Here, we allow this function to also take as input additional invariant edge features.
|
In the original papers, this function only depends on relative node distances ||x||.
|
||||||
This does not break equivariance and adds expressive power to the model.
|
Here, we allow this function to also take as input additional invariant edge features.
|
||||||
|
This does not break equivariance and adds expressive power to the model.
|
||||||
Diagram:
|
|
||||||
invariant edge features (node distances included) ───> MLP layer (shared across edges) ───> radial weights
|
Diagram:
|
||||||
"""
|
invariant edge features (node distances included) ───> MLP layer (shared across edges) ───> radial weights
|
||||||
|
"""
|
||||||
def __init__(
|
|
||||||
self,
|
def __init__(
|
||||||
num_freq: int,
|
self,
|
||||||
channels_in: int,
|
num_freq: int,
|
||||||
channels_out: int,
|
channels_in: int,
|
||||||
edge_dim: int = 1,
|
channels_out: int,
|
||||||
mid_dim: int = 32,
|
edge_dim: int = 1,
|
||||||
use_layer_norm: bool = False
|
mid_dim: int = 32,
|
||||||
):
|
use_layer_norm: bool = False
|
||||||
"""
|
):
|
||||||
:param num_freq: Number of frequencies
|
"""
|
||||||
:param channels_in: Number of input channels
|
:param num_freq: Number of frequencies
|
||||||
:param channels_out: Number of output channels
|
:param channels_in: Number of input channels
|
||||||
:param edge_dim: Number of invariant edge features (input to the radial function)
|
:param channels_out: Number of output channels
|
||||||
:param mid_dim: Size of the hidden MLP layers
|
:param edge_dim: Number of invariant edge features (input to the radial function)
|
||||||
:param use_layer_norm: Apply layer normalization between MLP layers
|
:param mid_dim: Size of the hidden MLP layers
|
||||||
"""
|
:param use_layer_norm: Apply layer normalization between MLP layers
|
||||||
super().__init__()
|
"""
|
||||||
modules = [
|
super().__init__()
|
||||||
nn.Linear(edge_dim, mid_dim),
|
modules = [
|
||||||
nn.LayerNorm(mid_dim) if use_layer_norm else None,
|
nn.Linear(edge_dim, mid_dim),
|
||||||
nn.ReLU(),
|
nn.LayerNorm(mid_dim) if use_layer_norm else None,
|
||||||
nn.Linear(mid_dim, mid_dim),
|
nn.ReLU(),
|
||||||
nn.LayerNorm(mid_dim) if use_layer_norm else None,
|
nn.Linear(mid_dim, mid_dim),
|
||||||
nn.ReLU(),
|
nn.LayerNorm(mid_dim) if use_layer_norm else None,
|
||||||
nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
|
nn.ReLU(),
|
||||||
]
|
nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
|
||||||
|
]
|
||||||
self.net = nn.Sequential(*[m for m in modules if m is not None])
|
|
||||||
|
self.net = nn.Sequential(*[m for m in modules if m is not None])
|
||||||
def forward(self, features: Tensor) -> Tensor:
|
|
||||||
return self.net(features)
|
def forward(self, features: Tensor) -> Tensor:
|
||||||
|
return self.net(features)
|
||||||
|
|
||||||
class VersatileConvSE3(nn.Module):
|
|
||||||
"""
|
class VersatileConvSE3(nn.Module):
|
||||||
Building block for TFN convolutions.
|
"""
|
||||||
This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions.
|
Building block for TFN convolutions.
|
||||||
"""
|
This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions.
|
||||||
|
"""
|
||||||
def __init__(self,
|
|
||||||
freq_sum: int,
|
def __init__(self,
|
||||||
channels_in: int,
|
freq_sum: int,
|
||||||
channels_out: int,
|
channels_in: int,
|
||||||
edge_dim: int,
|
channels_out: int,
|
||||||
use_layer_norm: bool,
|
edge_dim: int,
|
||||||
fuse_level: ConvSE3FuseLevel):
|
use_layer_norm: bool,
|
||||||
super().__init__()
|
fuse_level: ConvSE3FuseLevel):
|
||||||
self.freq_sum = freq_sum
|
super().__init__()
|
||||||
self.channels_out = channels_out
|
self.freq_sum = freq_sum
|
||||||
self.channels_in = channels_in
|
self.channels_out = channels_out
|
||||||
self.fuse_level = fuse_level
|
self.channels_in = channels_in
|
||||||
self.radial_func = RadialProfile(num_freq=freq_sum,
|
self.fuse_level = fuse_level
|
||||||
channels_in=channels_in,
|
self.radial_func = RadialProfile(num_freq=freq_sum,
|
||||||
channels_out=channels_out,
|
channels_in=channels_in,
|
||||||
edge_dim=edge_dim,
|
channels_out=channels_out,
|
||||||
use_layer_norm=use_layer_norm)
|
edge_dim=edge_dim,
|
||||||
|
use_layer_norm=use_layer_norm)
|
||||||
def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
|
|
||||||
with nvtx_range(f'VersatileConvSE3'):
|
def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
|
||||||
num_edges = features.shape[0]
|
with nvtx_range(f'VersatileConvSE3'):
|
||||||
in_dim = features.shape[2]
|
num_edges = features.shape[0]
|
||||||
with nvtx_range(f'RadialProfile'):
|
in_dim = features.shape[2]
|
||||||
radial_weights = self.radial_func(invariant_edge_feats) \
|
with nvtx_range(f'RadialProfile'):
|
||||||
.view(-1, self.channels_out, self.channels_in * self.freq_sum)
|
radial_weights = self.radial_func(invariant_edge_feats) \
|
||||||
|
.view(-1, self.channels_out, self.channels_in * self.freq_sum)
|
||||||
if basis is not None:
|
|
||||||
# This block performs the einsum n i l, n o i f, n l f k -> n o k
|
if basis is not None:
|
||||||
basis_view = basis.view(num_edges, in_dim, -1)
|
# This block performs the einsum n i l, n o i f, n l f k -> n o k
|
||||||
tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
|
basis_view = basis.view(num_edges, in_dim, -1)
|
||||||
return radial_weights @ tmp
|
tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
|
||||||
else:
|
return radial_weights @ tmp
|
||||||
# k = l = 0 non-fused case
|
else:
|
||||||
return radial_weights @ features
|
# k = l = 0 non-fused case
|
||||||
|
return radial_weights @ features
|
||||||
|
|
||||||
class ConvSE3(nn.Module):
|
|
||||||
"""
|
class ConvSE3(nn.Module):
|
||||||
SE(3)-equivariant graph convolution (Tensor Field Network convolution).
|
"""
|
||||||
This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance.
|
SE(3)-equivariant graph convolution (Tensor Field Network convolution).
|
||||||
Features of different degrees interact together to produce output features.
|
This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance.
|
||||||
|
Features of different degrees interact together to produce output features.
|
||||||
Note 1:
|
|
||||||
The option is given to not pool the output. This means that the convolution sum over neighbors will not be
|
Note 1:
|
||||||
done, and the returned features will be edge features instead of node features.
|
The option is given to not pool the output. This means that the convolution sum over neighbors will not be
|
||||||
|
done, and the returned features will be edge features instead of node features.
|
||||||
Note 2:
|
|
||||||
Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0.
|
Note 2:
|
||||||
Input edge features are concatenated with input source node features before the kernel is applied.
|
Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0.
|
||||||
"""
|
Input edge features are concatenated with input source node features before the kernel is applied.
|
||||||
|
"""
|
||||||
def __init__(
|
|
||||||
self,
|
def __init__(
|
||||||
fiber_in: Fiber,
|
self,
|
||||||
fiber_out: Fiber,
|
fiber_in: Fiber,
|
||||||
fiber_edge: Fiber,
|
fiber_out: Fiber,
|
||||||
pool: bool = True,
|
fiber_edge: Fiber,
|
||||||
use_layer_norm: bool = False,
|
pool: bool = True,
|
||||||
self_interaction: bool = False,
|
use_layer_norm: bool = False,
|
||||||
max_degree: int = 4,
|
self_interaction: bool = False,
|
||||||
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
|
max_degree: int = 4,
|
||||||
allow_fused_output: bool = False
|
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
|
||||||
):
|
allow_fused_output: bool = False,
|
||||||
"""
|
low_memory: bool = False
|
||||||
:param fiber_in: Fiber describing the input features
|
):
|
||||||
:param fiber_out: Fiber describing the output features
|
"""
|
||||||
:param fiber_edge: Fiber describing the edge features (node distances excluded)
|
:param fiber_in: Fiber describing the input features
|
||||||
:param pool: If True, compute final node features by averaging incoming edge features
|
:param fiber_out: Fiber describing the output features
|
||||||
:param use_layer_norm: Apply layer normalization between MLP layers
|
:param fiber_edge: Fiber describing the edge features (node distances excluded)
|
||||||
:param self_interaction: Apply self-interaction of nodes
|
:param pool: If True, compute final node features by averaging incoming edge features
|
||||||
:param max_degree: Maximum degree used in the bases computation
|
:param use_layer_norm: Apply layer normalization between MLP layers
|
||||||
:param fuse_level: Maximum fuse level to use in TFN convolutions
|
:param self_interaction: Apply self-interaction of nodes
|
||||||
:param allow_fused_output: Allow the module to output a fused representation of features
|
:param max_degree: Maximum degree used in the bases computation
|
||||||
"""
|
:param fuse_level: Maximum fuse level to use in TFN convolutions
|
||||||
super().__init__()
|
:param allow_fused_output: Allow the module to output a fused representation of features
|
||||||
self.pool = pool
|
"""
|
||||||
self.fiber_in = fiber_in
|
super().__init__()
|
||||||
self.fiber_out = fiber_out
|
self.pool = pool
|
||||||
self.self_interaction = self_interaction
|
self.fiber_in = fiber_in
|
||||||
self.max_degree = max_degree
|
self.fiber_out = fiber_out
|
||||||
self.allow_fused_output = allow_fused_output
|
self.self_interaction = self_interaction
|
||||||
|
self.max_degree = max_degree
|
||||||
# channels_in: account for the concatenation of edge features
|
self.allow_fused_output = allow_fused_output
|
||||||
channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
|
self.conv_checkpoint = torch.utils.checkpoint.checkpoint if low_memory else lambda m, *x: m(*x)
|
||||||
channels_out_set = set([f.channels for f in self.fiber_out])
|
|
||||||
unique_channels_in = (len(channels_in_set) == 1)
|
# channels_in: account for the concatenation of edge features
|
||||||
unique_channels_out = (len(channels_out_set) == 1)
|
channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
|
||||||
degrees_up_to_max = list(range(max_degree + 1))
|
channels_out_set = set([f.channels for f in self.fiber_out])
|
||||||
common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm)
|
unique_channels_in = (len(channels_in_set) == 1)
|
||||||
|
unique_channels_out = (len(channels_out_set) == 1)
|
||||||
if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
|
degrees_up_to_max = list(range(max_degree + 1))
|
||||||
unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
|
common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm)
|
||||||
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
|
|
||||||
# Single fused convolution
|
if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
|
||||||
self.used_fuse_level = ConvSE3FuseLevel.FULL
|
unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
|
||||||
|
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
|
||||||
sum_freq = sum([
|
# Single fused convolution
|
||||||
degree_to_dim(min(d_in, d_out))
|
self.used_fuse_level = ConvSE3FuseLevel.FULL
|
||||||
for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
|
|
||||||
])
|
sum_freq = sum([
|
||||||
|
degree_to_dim(min(d_in, d_out))
|
||||||
self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0],
|
for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
|
||||||
fuse_level=self.used_fuse_level, **common_args)
|
])
|
||||||
|
|
||||||
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
|
self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0],
|
||||||
unique_channels_in and fiber_in.degrees == degrees_up_to_max:
|
fuse_level=self.used_fuse_level, **common_args)
|
||||||
# Convolutions fused per output degree
|
|
||||||
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
|
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
|
||||||
self.conv_out = nn.ModuleDict()
|
unique_channels_in and fiber_in.degrees == degrees_up_to_max:
|
||||||
for d_out, c_out in fiber_out:
|
# Convolutions fused per output degree
|
||||||
sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
|
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
|
||||||
self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out,
|
self.conv_out = nn.ModuleDict()
|
||||||
fuse_level=self.used_fuse_level, **common_args)
|
for d_out, c_out in fiber_out:
|
||||||
|
sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
|
||||||
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
|
self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out,
|
||||||
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
|
fuse_level=self.used_fuse_level, **common_args)
|
||||||
# Convolutions fused per input degree
|
|
||||||
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
|
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
|
||||||
self.conv_in = nn.ModuleDict()
|
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
|
||||||
for d_in, c_in in fiber_in:
|
# Convolutions fused per input degree
|
||||||
sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
|
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
|
||||||
channels_in_new = c_in + fiber_edge[d_in] * (d_in > 0)
|
self.conv_in = nn.ModuleDict()
|
||||||
self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, channels_in_new, list(channels_out_set)[0],
|
for d_in, c_in in fiber_in:
|
||||||
fuse_level=self.used_fuse_level, **common_args)
|
channels_in_new = c_in + fiber_edge[d_in] * (d_in > 0)
|
||||||
else:
|
sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
|
||||||
# Use pairwise TFN convolutions
|
self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, channels_in_new, list(channels_out_set)[0],
|
||||||
self.used_fuse_level = ConvSE3FuseLevel.NONE
|
fuse_level=self.used_fuse_level, **common_args)
|
||||||
self.conv = nn.ModuleDict()
|
else:
|
||||||
for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out):
|
# Use pairwise TFN convolutions
|
||||||
dict_key = f'{degree_in},{degree_out}'
|
self.used_fuse_level = ConvSE3FuseLevel.NONE
|
||||||
channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
|
self.conv = nn.ModuleDict()
|
||||||
sum_freq = degree_to_dim(min(degree_in, degree_out))
|
for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out):
|
||||||
self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out,
|
dict_key = f'{degree_in},{degree_out}'
|
||||||
fuse_level=self.used_fuse_level, **common_args)
|
channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
|
||||||
|
sum_freq = degree_to_dim(min(degree_in, degree_out))
|
||||||
if self_interaction:
|
self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out,
|
||||||
self.to_kernel_self = nn.ParameterDict()
|
fuse_level=self.used_fuse_level, **common_args)
|
||||||
for degree_out, channels_out in fiber_out:
|
|
||||||
if fiber_in[degree_out]:
|
if self_interaction:
|
||||||
self.to_kernel_self[str(degree_out)] = nn.Parameter(
|
self.to_kernel_self = nn.ParameterDict()
|
||||||
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
|
for degree_out, channels_out in fiber_out:
|
||||||
|
if fiber_in[degree_out]:
|
||||||
def _try_unpad(self, feature, basis):
|
self.to_kernel_self[str(degree_out)] = nn.Parameter(
|
||||||
# Account for padded basis
|
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
|
||||||
if basis is not None:
|
|
||||||
out_dim = basis.shape[-1]
|
def _try_unpad(self, feature, basis):
|
||||||
out_dim += out_dim % 2 - 1
|
# Account for padded basis
|
||||||
return feature[..., :out_dim]
|
if basis is not None:
|
||||||
else:
|
out_dim = basis.shape[-1]
|
||||||
return feature
|
out_dim += out_dim % 2 - 1
|
||||||
|
return feature[..., :out_dim]
|
||||||
def forward(
|
else:
|
||||||
self,
|
return feature
|
||||||
node_feats: Dict[str, Tensor],
|
|
||||||
edge_feats: Dict[str, Tensor],
|
def forward(
|
||||||
graph: DGLGraph,
|
self,
|
||||||
basis: Dict[str, Tensor]
|
node_feats: Dict[str, Tensor],
|
||||||
):
|
edge_feats: Dict[str, Tensor],
|
||||||
with nvtx_range(f'ConvSE3'):
|
graph: DGLGraph,
|
||||||
invariant_edge_feats = edge_feats['0'].squeeze(-1)
|
basis: Dict[str, Tensor]
|
||||||
src, dst = graph.edges()
|
):
|
||||||
out = {}
|
with nvtx_range(f'ConvSE3'):
|
||||||
in_features = []
|
invariant_edge_feats = edge_feats['0'].squeeze(-1)
|
||||||
|
src, dst = graph.edges()
|
||||||
# Fetch all input features from edge and node features
|
out = {}
|
||||||
for degree_in in self.fiber_in.degrees:
|
in_features = []
|
||||||
src_node_features = node_feats[str(degree_in)][src]
|
|
||||||
if degree_in > 0 and str(degree_in) in edge_feats:
|
# Fetch all input features from edge and node features
|
||||||
# Handle edge features of any type by concatenating them to node features
|
for degree_in in self.fiber_in.degrees:
|
||||||
src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
|
src_node_features = node_feats[str(degree_in)][src]
|
||||||
in_features.append(src_node_features)
|
if degree_in > 0 and str(degree_in) in edge_feats:
|
||||||
|
# Handle edge features of any type by concatenating them to node features
|
||||||
if self.used_fuse_level == ConvSE3FuseLevel.FULL:
|
src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
|
||||||
in_features_fused = torch.cat(in_features, dim=-1)
|
in_features.append(src_node_features)
|
||||||
out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])
|
|
||||||
|
if self.used_fuse_level == ConvSE3FuseLevel.FULL:
|
||||||
if not self.allow_fused_output or self.self_interaction or self.pool:
|
in_features_fused = torch.cat(in_features, dim=-1)
|
||||||
out = unfuse_features(out, self.fiber_out.degrees)
|
out = self.conv_checkpoint(
|
||||||
|
self.conv, in_features_fused, invariant_edge_feats, basis['fully_fused']
|
||||||
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
|
)
|
||||||
in_features_fused = torch.cat(in_features, dim=-1)
|
|
||||||
for degree_out in self.fiber_out.degrees:
|
if not self.allow_fused_output or self.self_interaction or self.pool:
|
||||||
basis_used = basis[f'out{degree_out}_fused']
|
out = unfuse_features(out, self.fiber_out.degrees)
|
||||||
out[str(degree_out)] = self._try_unpad(
|
|
||||||
self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats, basis_used),
|
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
|
||||||
basis_used)
|
in_features_fused = torch.cat(in_features, dim=-1)
|
||||||
|
for degree_out in self.fiber_out.degrees:
|
||||||
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
|
basis_used = basis[f'out{degree_out}_fused']
|
||||||
out = 0
|
out[str(degree_out)] = self._try_unpad(
|
||||||
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
|
self.conv_checkpoint(
|
||||||
out = out + self.conv_in[str(degree_in)](feature, invariant_edge_feats, basis[f'in{degree_in}_fused'])
|
self.conv_out[str(degree_out)], in_features_fused, invariant_edge_feats, basis_used
|
||||||
if not self.allow_fused_output or self.self_interaction or self.pool:
|
), basis_used)
|
||||||
out = unfuse_features(out, self.fiber_out.degrees)
|
|
||||||
else:
|
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
|
||||||
# Fallback to pairwise TFN convolutions
|
out = 0
|
||||||
for degree_out in self.fiber_out.degrees:
|
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
|
||||||
out_feature = 0
|
out = out + self.conv_checkpoint(
|
||||||
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
|
self.conv_in[str(degree_in)], feature, invariant_edge_feats, basis[f'in{degree_in}_fused']
|
||||||
dict_key = f'{degree_in},{degree_out}'
|
)
|
||||||
basis_used = basis.get(dict_key, None)
|
if not self.allow_fused_output or self.self_interaction or self.pool:
|
||||||
out_feature = out_feature + self._try_unpad(
|
out = unfuse_features(out, self.fiber_out.degrees)
|
||||||
self.conv[dict_key](feature, invariant_edge_feats, basis_used),
|
else:
|
||||||
basis_used)
|
# Fallback to pairwise TFN convolutions
|
||||||
out[str(degree_out)] = out_feature
|
for degree_out in self.fiber_out.degrees:
|
||||||
|
out_feature = 0
|
||||||
for degree_out in self.fiber_out.degrees:
|
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
|
||||||
if self.self_interaction and str(degree_out) in self.to_kernel_self:
|
dict_key = f'{degree_in},{degree_out}'
|
||||||
with nvtx_range(f'self interaction'):
|
basis_used = basis.get(dict_key, None)
|
||||||
dst_features = node_feats[str(degree_out)][dst]
|
out_feature = out_feature + self._try_unpad(
|
||||||
kernel_self = self.to_kernel_self[str(degree_out)]
|
self.conv_checkpoint(
|
||||||
out[str(degree_out)] = out[str(degree_out)] + kernel_self @ dst_features
|
self.conv[dict_key], feature, invariant_edge_feats, basis_used
|
||||||
|
), basis_used)
|
||||||
if self.pool:
|
out[str(degree_out)] = out_feature
|
||||||
with nvtx_range(f'pooling'):
|
|
||||||
if isinstance(out, dict):
|
for degree_out in self.fiber_out.degrees:
|
||||||
out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
|
if self.self_interaction and str(degree_out) in self.to_kernel_self:
|
||||||
else:
|
with nvtx_range(f'self interaction'):
|
||||||
out = dgl.ops.copy_e_sum(graph, out)
|
dst_features = node_feats[str(degree_out)][dst]
|
||||||
return out
|
kernel_self = self.to_kernel_self[str(degree_out)]
|
||||||
|
out[str(degree_out)] = out[str(degree_out)] + kernel_self @ dst_features
|
||||||
|
|
||||||
|
if self.pool:
|
||||||
|
with nvtx_range(f'pooling'):
|
||||||
|
if isinstance(out, dict):
|
||||||
|
out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
|
||||||
|
else:
|
||||||
|
out = dgl.ops.copy_e_sum(graph, out)
|
||||||
|
return out
|
||||||
|
|
|
@ -1,59 +1,59 @@
|
||||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
#
|
#
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
# copy of this software and associated documentation files (the "Software"),
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
# to deal in the Software without restriction, including without limitation
|
# to deal in the Software without restriction, including without limitation
|
||||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
# and/or sell copies of the Software, and to permit persons to whom the
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
# Software is furnished to do so, subject to the following conditions:
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
#
|
#
|
||||||
# The above copyright notice and this permission notice shall be included in
|
# The above copyright notice and this permission notice shall be included in
|
||||||
# all copies or substantial portions of the Software.
|
# all copies or substantial portions of the Software.
|
||||||
#
|
#
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
# DEALINGS IN THE SOFTWARE.
|
# DEALINGS IN THE SOFTWARE.
|
||||||
#
|
#
|
||||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from se3_transformer.model.fiber import Fiber
|
from se3_transformer.model.fiber import Fiber
|
||||||
|
|
||||||
|
|
||||||
class LinearSE3(nn.Module):
|
class LinearSE3(nn.Module):
|
||||||
"""
|
"""
|
||||||
Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
|
Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
|
||||||
Maps a fiber to a fiber with the same degrees (channels may be different).
|
Maps a fiber to a fiber with the same degrees (channels may be different).
|
||||||
No interaction between degrees, but interaction between channels.
|
No interaction between degrees, but interaction between channels.
|
||||||
|
|
||||||
type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels)
|
type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels)
|
||||||
type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels)
|
type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels)
|
||||||
:
|
:
|
||||||
type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels)
|
type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
|
def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weights = nn.ParameterDict({
|
self.weights = nn.ParameterDict({
|
||||||
str(degree_out): nn.Parameter(
|
str(degree_out): nn.Parameter(
|
||||||
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
|
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
|
||||||
for degree_out, channels_out in fiber_out
|
for degree_out, channels_out in fiber_out
|
||||||
})
|
})
|
||||||
|
|
||||||
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
|
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
|
||||||
return {
|
return {
|
||||||
degree: self.weights[degree] @ features[degree]
|
degree: self.weights[degree] @ features[degree]
|
||||||
for degree, weight in self.weights.items()
|
for degree, weight in self.weights.items()
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,83 +1,83 @@
|
||||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
#
|
#
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
# copy of this software and associated documentation files (the "Software"),
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
# to deal in the Software without restriction, including without limitation
|
# to deal in the Software without restriction, including without limitation
|
||||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
# and/or sell copies of the Software, and to permit persons to whom the
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
# Software is furnished to do so, subject to the following conditions:
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
#
|
#
|
||||||
# The above copyright notice and this permission notice shall be included in
|
# The above copyright notice and this permission notice shall be included in
|
||||||
# all copies or substantial portions of the Software.
|
# all copies or substantial portions of the Software.
|
||||||
#
|
#
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
# DEALINGS IN THE SOFTWARE.
|
# DEALINGS IN THE SOFTWARE.
|
||||||
#
|
#
|
||||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.cuda.nvtx import range as nvtx_range
|
from torch.cuda.nvtx import range as nvtx_range
|
||||||
|
|
||||||
from se3_transformer.model.fiber import Fiber
|
from se3_transformer.model.fiber import Fiber
|
||||||
|
|
||||||
|
|
||||||
class NormSE3(nn.Module):
|
class NormSE3(nn.Module):
|
||||||
"""
|
"""
|
||||||
Norm-based SE(3)-equivariant nonlinearity.
|
Norm-based SE(3)-equivariant nonlinearity.
|
||||||
|
|
||||||
┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐
|
┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐
|
||||||
feature_in ──┤ * ──> feature_out
|
feature_in ──┤ * ──> feature_out
|
||||||
└──> feature_phase ────────────────────────────┘
|
└──> feature_phase ────────────────────────────┘
|
||||||
"""
|
"""
|
||||||
|
|
||||||
NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
|
NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
|
||||||
|
|
||||||
def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
|
def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.fiber = fiber
|
self.fiber = fiber
|
||||||
self.nonlinearity = nonlinearity
|
self.nonlinearity = nonlinearity
|
||||||
|
|
||||||
if len(set(fiber.channels)) == 1:
|
if len(set(fiber.channels)) == 1:
|
||||||
# Fuse all the layer normalizations into a group normalization
|
# Fuse all the layer normalizations into a group normalization
|
||||||
self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
|
self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
|
||||||
else:
|
else:
|
||||||
# Use multiple layer normalizations
|
# Use multiple layer normalizations
|
||||||
self.layer_norms = nn.ModuleDict({
|
self.layer_norms = nn.ModuleDict({
|
||||||
str(degree): nn.LayerNorm(channels)
|
str(degree): nn.LayerNorm(channels)
|
||||||
for degree, channels in fiber
|
for degree, channels in fiber
|
||||||
})
|
})
|
||||||
|
|
||||||
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
|
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
|
||||||
with nvtx_range('NormSE3'):
|
with nvtx_range('NormSE3'):
|
||||||
output = {}
|
output = {}
|
||||||
if hasattr(self, 'group_norm'):
|
if hasattr(self, 'group_norm'):
|
||||||
# Compute per-degree norms of features
|
# Compute per-degree norms of features
|
||||||
norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
|
norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
|
||||||
for d in self.fiber.degrees]
|
for d in self.fiber.degrees]
|
||||||
fused_norms = torch.cat(norms, dim=-2)
|
fused_norms = torch.cat(norms, dim=-2)
|
||||||
|
|
||||||
# Transform the norms only
|
# Transform the norms only
|
||||||
new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
|
new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
|
||||||
new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)
|
new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)
|
||||||
|
|
||||||
# Scale features to the new norms
|
# Scale features to the new norms
|
||||||
for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
|
for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
|
||||||
output[str(d)] = features[str(d)] / norm * new_norm
|
output[str(d)] = features[str(d)] / norm * new_norm
|
||||||
else:
|
else:
|
||||||
for degree, feat in features.items():
|
for degree, feat in features.items():
|
||||||
norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
|
norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
|
||||||
new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
|
new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
|
||||||
output[degree] = new_norm * feat / norm
|
output[degree] = new_norm * feat / norm
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
|
@ -1,53 +1,53 @@
|
||||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
#
|
#
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
# copy of this software and associated documentation files (the "Software"),
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
# to deal in the Software without restriction, including without limitation
|
# to deal in the Software without restriction, including without limitation
|
||||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
# and/or sell copies of the Software, and to permit persons to whom the
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
# Software is furnished to do so, subject to the following conditions:
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
#
|
#
|
||||||
# The above copyright notice and this permission notice shall be included in
|
# The above copyright notice and this permission notice shall be included in
|
||||||
# all copies or substantial portions of the Software.
|
# all copies or substantial portions of the Software.
|
||||||
#
|
#
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
# DEALINGS IN THE SOFTWARE.
|
# DEALINGS IN THE SOFTWARE.
|
||||||
#
|
#
|
||||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
from typing import Dict, Literal
|
from typing import Dict, Literal
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from dgl import DGLGraph
|
from dgl import DGLGraph
|
||||||
from dgl.nn.pytorch import AvgPooling, MaxPooling
|
from dgl.nn.pytorch import AvgPooling, MaxPooling
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
class GPooling(nn.Module):
|
class GPooling(nn.Module):
|
||||||
"""
|
"""
|
||||||
Graph max/average pooling on a given feature type.
|
Graph max/average pooling on a given feature type.
|
||||||
The average can be taken for any feature type, and equivariance will be maintained.
|
The average can be taken for any feature type, and equivariance will be maintained.
|
||||||
The maximum can only be taken for invariant features (type 0).
|
The maximum can only be taken for invariant features (type 0).
|
||||||
If you want max-pooling for type > 0 features, look into Vector Neurons.
|
If you want max-pooling for type > 0 features, look into Vector Neurons.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
|
def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
|
||||||
"""
|
"""
|
||||||
:param feat_type: Feature type to pool
|
:param feat_type: Feature type to pool
|
||||||
:param pool: Type of pooling: max or avg
|
:param pool: Type of pooling: max or avg
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
|
assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
|
||||||
assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
|
assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
|
||||||
self.feat_type = feat_type
|
self.feat_type = feat_type
|
||||||
self.pool = MaxPooling() if pool == 'max' else AvgPooling()
|
self.pool = MaxPooling() if pool == 'max' else AvgPooling()
|
||||||
|
|
||||||
def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
|
def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
|
||||||
pooled = self.pool(graph, features[str(self.feat_type)])
|
pooled = self.pool(graph, features[str(self.feat_type)])
|
||||||
return pooled.squeeze(dim=-1)
|
return pooled.squeeze(dim=-1)
|
||||||
|
|
|
@ -1,222 +1,223 @@
|
||||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
#
|
#
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
# copy of this software and associated documentation files (the "Software"),
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
# to deal in the Software without restriction, including without limitation
|
# to deal in the Software without restriction, including without limitation
|
||||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
# and/or sell copies of the Software, and to permit persons to whom the
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
# Software is furnished to do so, subject to the following conditions:
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
#
|
#
|
||||||
# The above copyright notice and this permission notice shall be included in
|
# The above copyright notice and this permission notice shall be included in
|
||||||
# all copies or substantial portions of the Software.
|
# all copies or substantial portions of the Software.
|
||||||
#
|
#
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
# DEALINGS IN THE SOFTWARE.
|
# DEALINGS IN THE SOFTWARE.
|
||||||
#
|
#
|
||||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Literal, Dict
|
from typing import Optional, Literal, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from dgl import DGLGraph
|
from dgl import DGLGraph
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from se3_transformer.model.basis import get_basis, update_basis_with_fused
|
from se3_transformer.model.basis import get_basis, update_basis_with_fused
|
||||||
from se3_transformer.model.layers.attention import AttentionBlockSE3
|
from se3_transformer.model.layers.attention import AttentionBlockSE3
|
||||||
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
||||||
from se3_transformer.model.layers.norm import NormSE3
|
from se3_transformer.model.layers.norm import NormSE3
|
||||||
from se3_transformer.model.layers.pooling import GPooling
|
from se3_transformer.model.layers.pooling import GPooling
|
||||||
from se3_transformer.runtime.utils import str2bool
|
from se3_transformer.runtime.utils import str2bool
|
||||||
from se3_transformer.model.fiber import Fiber
|
from se3_transformer.model.fiber import Fiber
|
||||||
|
|
||||||
|
|
||||||
class Sequential(nn.Sequential):
|
class Sequential(nn.Sequential):
|
||||||
""" Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """
|
""" Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """
|
||||||
|
|
||||||
def forward(self, input, *args, **kwargs):
|
def forward(self, input, *args, **kwargs):
|
||||||
for module in self:
|
for module in self:
|
||||||
input = module(input, *args, **kwargs)
|
input = module(input, *args, **kwargs)
|
||||||
return input
|
return input
|
||||||
|
|
||||||
|
|
||||||
def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
|
def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
|
||||||
""" Add relative positions to existing edge features """
|
""" Add relative positions to existing edge features """
|
||||||
edge_features = edge_features.copy() if edge_features else {}
|
edge_features = edge_features.copy() if edge_features else {}
|
||||||
r = relative_pos.norm(dim=-1, keepdim=True)
|
r = relative_pos.norm(dim=-1, keepdim=True)
|
||||||
if '0' in edge_features:
|
if '0' in edge_features:
|
||||||
edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
|
edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
|
||||||
else:
|
else:
|
||||||
edge_features['0'] = r[..., None]
|
edge_features['0'] = r[..., None]
|
||||||
|
|
||||||
return edge_features
|
return edge_features
|
||||||
|
|
||||||
|
|
||||||
class SE3Transformer(nn.Module):
|
class SE3Transformer(nn.Module):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_layers: int,
|
num_layers: int,
|
||||||
fiber_in: Fiber,
|
fiber_in: Fiber,
|
||||||
fiber_hidden: Fiber,
|
fiber_hidden: Fiber,
|
||||||
fiber_out: Fiber,
|
fiber_out: Fiber,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
channels_div: int,
|
channels_div: int,
|
||||||
fiber_edge: Fiber = Fiber({}),
|
fiber_edge: Fiber = Fiber({}),
|
||||||
return_type: Optional[int] = None,
|
return_type: Optional[int] = None,
|
||||||
pooling: Optional[Literal['avg', 'max']] = None,
|
pooling: Optional[Literal['avg', 'max']] = None,
|
||||||
norm: bool = True,
|
norm: bool = True,
|
||||||
use_layer_norm: bool = True,
|
use_layer_norm: bool = True,
|
||||||
tensor_cores: bool = False,
|
tensor_cores: bool = False,
|
||||||
low_memory: bool = False,
|
low_memory: bool = False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
:param num_layers: Number of attention layers
|
:param num_layers: Number of attention layers
|
||||||
:param fiber_in: Input fiber description
|
:param fiber_in: Input fiber description
|
||||||
:param fiber_hidden: Hidden fiber description
|
:param fiber_hidden: Hidden fiber description
|
||||||
:param fiber_out: Output fiber description
|
:param fiber_out: Output fiber description
|
||||||
:param fiber_edge: Input edge fiber description
|
:param fiber_edge: Input edge fiber description
|
||||||
:param num_heads: Number of attention heads
|
:param num_heads: Number of attention heads
|
||||||
:param channels_div: Channels division before feeding to attention layer
|
:param channels_div: Channels division before feeding to attention layer
|
||||||
:param return_type: Return only features of this type
|
:param return_type: Return only features of this type
|
||||||
:param pooling: 'avg' or 'max' graph pooling before MLP layers
|
:param pooling: 'avg' or 'max' graph pooling before MLP layers
|
||||||
:param norm: Apply a normalization layer after each attention block
|
:param norm: Apply a normalization layer after each attention block
|
||||||
:param use_layer_norm: Apply layer normalization between MLP layers
|
:param use_layer_norm: Apply layer normalization between MLP layers
|
||||||
:param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
|
:param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
|
||||||
:param low_memory: If True, will use slower ops that use less memory
|
:param low_memory: If True, will use slower ops that use less memory
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
self.fiber_edge = fiber_edge
|
self.fiber_edge = fiber_edge
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.channels_div = channels_div
|
self.channels_div = channels_div
|
||||||
self.return_type = return_type
|
self.return_type = return_type
|
||||||
self.pooling = pooling
|
self.pooling = pooling
|
||||||
self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
|
self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
|
||||||
self.tensor_cores = tensor_cores
|
self.tensor_cores = tensor_cores
|
||||||
self.low_memory = low_memory
|
self.low_memory = low_memory
|
||||||
|
|
||||||
if low_memory and not tensor_cores:
|
if low_memory:
|
||||||
logging.warning('Low memory mode will have no effect with no Tensor Cores')
|
self.fuse_level = ConvSE3FuseLevel.NONE
|
||||||
|
else:
|
||||||
# Fully fused convolutions when using Tensor Cores (and not low memory mode)
|
# Fully fused convolutions when using Tensor Cores (and not low memory mode)
|
||||||
fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL
|
self.fuse_level = ConvSE3FuseLevel.FULL if tensor_cores else ConvSE3FuseLevel.PARTIAL
|
||||||
|
|
||||||
graph_modules = []
|
graph_modules = []
|
||||||
for i in range(num_layers):
|
for i in range(num_layers):
|
||||||
graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
|
graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
|
||||||
fiber_out=fiber_hidden,
|
fiber_out=fiber_hidden,
|
||||||
fiber_edge=fiber_edge,
|
fiber_edge=fiber_edge,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
channels_div=channels_div,
|
channels_div=channels_div,
|
||||||
use_layer_norm=use_layer_norm,
|
use_layer_norm=use_layer_norm,
|
||||||
max_degree=self.max_degree,
|
max_degree=self.max_degree,
|
||||||
fuse_level=fuse_level))
|
fuse_level=self.fuse_level,
|
||||||
if norm:
|
low_memory=low_memory))
|
||||||
graph_modules.append(NormSE3(fiber_hidden))
|
if norm:
|
||||||
fiber_in = fiber_hidden
|
graph_modules.append(NormSE3(fiber_hidden))
|
||||||
|
fiber_in = fiber_hidden
|
||||||
graph_modules.append(ConvSE3(fiber_in=fiber_in,
|
|
||||||
fiber_out=fiber_out,
|
graph_modules.append(ConvSE3(fiber_in=fiber_in,
|
||||||
fiber_edge=fiber_edge,
|
fiber_out=fiber_out,
|
||||||
self_interaction=True,
|
fiber_edge=fiber_edge,
|
||||||
use_layer_norm=use_layer_norm,
|
self_interaction=True,
|
||||||
max_degree=self.max_degree))
|
use_layer_norm=use_layer_norm,
|
||||||
self.graph_modules = Sequential(*graph_modules)
|
max_degree=self.max_degree))
|
||||||
|
self.graph_modules = Sequential(*graph_modules)
|
||||||
if pooling is not None:
|
|
||||||
assert return_type is not None, 'return_type must be specified when pooling'
|
if pooling is not None:
|
||||||
self.pooling_module = GPooling(pool=pooling, feat_type=return_type)
|
assert return_type is not None, 'return_type must be specified when pooling'
|
||||||
|
self.pooling_module = GPooling(pool=pooling, feat_type=return_type)
|
||||||
def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
|
|
||||||
edge_feats: Optional[Dict[str, Tensor]] = None,
|
def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
|
||||||
basis: Optional[Dict[str, Tensor]] = None):
|
edge_feats: Optional[Dict[str, Tensor]] = None,
|
||||||
# Compute bases in case they weren't precomputed as part of the data loading
|
basis: Optional[Dict[str, Tensor]] = None):
|
||||||
basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
|
# Compute bases in case they weren't precomputed as part of the data loading
|
||||||
use_pad_trick=self.tensor_cores and not self.low_memory,
|
basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
|
||||||
amp=torch.is_autocast_enabled())
|
use_pad_trick=self.tensor_cores and not self.low_memory,
|
||||||
|
amp=torch.is_autocast_enabled())
|
||||||
# Add fused bases (per output degree, per input degree, and fully fused) to the dict
|
|
||||||
basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
|
# Add fused bases (per output degree, per input degree, and fully fused) to the dict
|
||||||
fully_fused=self.tensor_cores and not self.low_memory)
|
basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
|
||||||
|
fully_fused=self.fuse_level == ConvSE3FuseLevel.FULL)
|
||||||
edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)
|
|
||||||
|
edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)
|
||||||
node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)
|
|
||||||
|
node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)
|
||||||
if self.pooling is not None:
|
|
||||||
return self.pooling_module(node_feats, graph=graph)
|
if self.pooling is not None:
|
||||||
|
return self.pooling_module(node_feats, graph=graph)
|
||||||
if self.return_type is not None:
|
|
||||||
return node_feats[str(self.return_type)]
|
if self.return_type is not None:
|
||||||
|
return node_feats[str(self.return_type)]
|
||||||
return node_feats
|
|
||||||
|
return node_feats
|
||||||
@staticmethod
|
|
||||||
def add_argparse_args(parser):
|
@staticmethod
|
||||||
parser.add_argument('--num_layers', type=int, default=7,
|
def add_argparse_args(parser):
|
||||||
help='Number of stacked Transformer layers')
|
parser.add_argument('--num_layers', type=int, default=7,
|
||||||
parser.add_argument('--num_heads', type=int, default=8,
|
help='Number of stacked Transformer layers')
|
||||||
help='Number of heads in self-attention')
|
parser.add_argument('--num_heads', type=int, default=8,
|
||||||
parser.add_argument('--channels_div', type=int, default=2,
|
help='Number of heads in self-attention')
|
||||||
help='Channels division before feeding to attention layer')
|
parser.add_argument('--channels_div', type=int, default=2,
|
||||||
parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
|
help='Channels division before feeding to attention layer')
|
||||||
help='Type of graph pooling')
|
parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
|
||||||
parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
|
help='Type of graph pooling')
|
||||||
help='Apply a normalization layer after each attention block')
|
parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
|
||||||
parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
|
help='Apply a normalization layer after each attention block')
|
||||||
help='Apply layer normalization between MLP layers')
|
parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
|
||||||
parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
|
help='Apply layer normalization between MLP layers')
|
||||||
help='If true, will use fused ops that are slower but that use less memory '
|
parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
|
||||||
'(expect 25 percent less memory). '
|
help='If true, will use fused ops that are slower but that use less memory '
|
||||||
'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')
|
'(expect 25 percent less memory). '
|
||||||
|
'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')
|
||||||
return parser
|
|
||||||
|
return parser
|
||||||
|
|
||||||
class SE3TransformerPooled(nn.Module):
|
|
||||||
def __init__(self,
|
class SE3TransformerPooled(nn.Module):
|
||||||
fiber_in: Fiber,
|
def __init__(self,
|
||||||
fiber_out: Fiber,
|
fiber_in: Fiber,
|
||||||
fiber_edge: Fiber,
|
fiber_out: Fiber,
|
||||||
num_degrees: int,
|
fiber_edge: Fiber,
|
||||||
num_channels: int,
|
num_degrees: int,
|
||||||
output_dim: int,
|
num_channels: int,
|
||||||
**kwargs):
|
output_dim: int,
|
||||||
super().__init__()
|
**kwargs):
|
||||||
kwargs['pooling'] = kwargs['pooling'] or 'max'
|
super().__init__()
|
||||||
self.transformer = SE3Transformer(
|
kwargs['pooling'] = kwargs['pooling'] or 'max'
|
||||||
fiber_in=fiber_in,
|
self.transformer = SE3Transformer(
|
||||||
fiber_hidden=Fiber.create(num_degrees, num_channels),
|
fiber_in=fiber_in,
|
||||||
fiber_out=fiber_out,
|
fiber_hidden=Fiber.create(num_degrees, num_channels),
|
||||||
fiber_edge=fiber_edge,
|
fiber_out=fiber_out,
|
||||||
return_type=0,
|
fiber_edge=fiber_edge,
|
||||||
**kwargs
|
return_type=0,
|
||||||
)
|
**kwargs
|
||||||
|
)
|
||||||
n_out_features = fiber_out.num_features
|
|
||||||
self.mlp = nn.Sequential(
|
n_out_features = fiber_out.num_features
|
||||||
nn.Linear(n_out_features, n_out_features),
|
self.mlp = nn.Sequential(
|
||||||
nn.ReLU(),
|
nn.Linear(n_out_features, n_out_features),
|
||||||
nn.Linear(n_out_features, output_dim)
|
nn.ReLU(),
|
||||||
)
|
nn.Linear(n_out_features, output_dim)
|
||||||
|
)
|
||||||
def forward(self, graph, node_feats, edge_feats, basis=None):
|
|
||||||
feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
|
def forward(self, graph, node_feats, edge_feats, basis=None):
|
||||||
y = self.mlp(feats).squeeze(-1)
|
feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
|
||||||
return y
|
y = self.mlp(feats).squeeze(-1)
|
||||||
|
return y
|
||||||
@staticmethod
|
|
||||||
def add_argparse_args(parent_parser):
|
@staticmethod
|
||||||
parser = parent_parser.add_argument_group("Model architecture")
|
def add_argparse_args(parent_parser):
|
||||||
SE3Transformer.add_argparse_args(parser)
|
parser = parent_parser.add_argument_group("Model architecture")
|
||||||
parser.add_argument('--num_degrees',
|
SE3Transformer.add_argparse_args(parser)
|
||||||
help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
|
parser.add_argument('--num_degrees',
|
||||||
type=int, default=4)
|
help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
|
||||||
parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
|
type=int, default=4)
|
||||||
return parent_parser
|
parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
|
||||||
|
return parent_parser
|
||||||
|
|
|
@ -1,70 +1,72 @@
|
||||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
#
|
#
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
# copy of this software and associated documentation files (the "Software"),
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
# to deal in the Software without restriction, including without limitation
|
# to deal in the Software without restriction, including without limitation
|
||||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
# and/or sell copies of the Software, and to permit persons to whom the
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
# Software is furnished to do so, subject to the following conditions:
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
#
|
#
|
||||||
# The above copyright notice and this permission notice shall be included in
|
# The above copyright notice and this permission notice shall be included in
|
||||||
# all copies or substantial portions of the Software.
|
# all copies or substantial portions of the Software.
|
||||||
#
|
#
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
# DEALINGS IN THE SOFTWARE.
|
# DEALINGS IN THE SOFTWARE.
|
||||||
#
|
#
|
||||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
from se3_transformer.data_loading import QM9DataModule
|
from se3_transformer.data_loading import QM9DataModule
|
||||||
from se3_transformer.model import SE3TransformerPooled
|
from se3_transformer.model import SE3TransformerPooled
|
||||||
from se3_transformer.runtime.utils import str2bool
|
from se3_transformer.runtime.utils import str2bool
|
||||||
|
|
||||||
PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')
|
PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')
|
||||||
|
|
||||||
paths = PARSER.add_argument_group('Paths')
|
paths = PARSER.add_argument_group('Paths')
|
||||||
paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'),
|
paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'),
|
||||||
help='Directory where the data is located or should be downloaded')
|
help='Directory where the data is located or should be downloaded')
|
||||||
paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'),
|
paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'),
|
||||||
help='Directory where the results logs should be saved')
|
help='Directory where the results logs should be saved')
|
||||||
paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json',
|
paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json',
|
||||||
help='Name for the resulting DLLogger JSON file')
|
help='Name for the resulting DLLogger JSON file')
|
||||||
paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None,
|
paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None,
|
||||||
help='File where the checkpoint should be saved')
|
help='File where the checkpoint should be saved')
|
||||||
paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None,
|
paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None,
|
||||||
help='File of the checkpoint to be loaded')
|
help='File of the checkpoint to be loaded')
|
||||||
|
|
||||||
optimizer = PARSER.add_argument_group('Optimizer')
|
optimizer = PARSER.add_argument_group('Optimizer')
|
||||||
optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam')
|
optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam')
|
||||||
optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002)
|
optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002)
|
||||||
optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None)
|
optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None)
|
||||||
optimizer.add_argument('--momentum', type=float, default=0.9)
|
optimizer.add_argument('--momentum', type=float, default=0.9)
|
||||||
optimizer.add_argument('--weight_decay', type=float, default=0.1)
|
optimizer.add_argument('--weight_decay', type=float, default=0.1)
|
||||||
|
|
||||||
PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
|
PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
|
||||||
PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size')
|
PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size')
|
||||||
PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally')
|
PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally')
|
||||||
PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers')
|
PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers')
|
||||||
|
|
||||||
PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision')
|
PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision')
|
||||||
PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
|
PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
|
||||||
PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
|
PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
|
||||||
PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
|
PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
|
||||||
PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=20,
|
PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=20,
|
||||||
help='Do an evaluation round every N epochs')
|
help='Do an evaluation round every N epochs')
|
||||||
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
|
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
|
||||||
help='Minimize stdout output')
|
help='Minimize stdout output')
|
||||||
|
PARSER.add_argument('--wandb', type=str2bool, nargs='?', const=True, default=False,
|
||||||
PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
|
help='Enable W&B logging')
|
||||||
help='Benchmark mode')
|
|
||||||
|
PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
|
||||||
QM9DataModule.add_argparse_args(PARSER)
|
help='Benchmark mode')
|
||||||
SE3TransformerPooled.add_argparse_args(PARSER)
|
|
||||||
|
QM9DataModule.add_argparse_args(PARSER)
|
||||||
|
SE3TransformerPooled.add_argparse_args(PARSER)
|
||||||
|
|
|
@ -32,7 +32,7 @@ from tqdm import tqdm
|
||||||
from se3_transformer.runtime import gpu_affinity
|
from se3_transformer.runtime import gpu_affinity
|
||||||
from se3_transformer.runtime.arguments import PARSER
|
from se3_transformer.runtime.arguments import PARSER
|
||||||
from se3_transformer.runtime.callbacks import BaseCallback
|
from se3_transformer.runtime.callbacks import BaseCallback
|
||||||
from se3_transformer.runtime.loggers import DLLogger
|
from se3_transformer.runtime.loggers import DLLogger, WandbLogger, LoggerCollection
|
||||||
from se3_transformer.runtime.utils import to_cuda, get_local_rank
|
from se3_transformer.runtime.utils import to_cuda, get_local_rank
|
||||||
|
|
||||||
|
|
||||||
|
@ -87,7 +87,10 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
major_cc, minor_cc = torch.cuda.get_device_capability()
|
major_cc, minor_cc = torch.cuda.get_device_capability()
|
||||||
|
|
||||||
logger = DLLogger(args.log_dir, filename=args.dllogger_name)
|
loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)]
|
||||||
|
if args.wandb:
|
||||||
|
loggers.append(WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer'))
|
||||||
|
logger = LoggerCollection(loggers)
|
||||||
datamodule = QM9DataModule(**vars(args))
|
datamodule = QM9DataModule(**vars(args))
|
||||||
model = SE3TransformerPooled(
|
model = SE3TransformerPooled(
|
||||||
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
|
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
|
||||||
|
@ -108,6 +111,7 @@ if __name__ == '__main__':
|
||||||
nproc_per_node = torch.cuda.device_count()
|
nproc_per_node = torch.cuda.device_count()
|
||||||
affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node)
|
affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node)
|
||||||
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
|
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
|
||||||
|
model._set_static_graph()
|
||||||
|
|
||||||
test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
|
test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
|
||||||
evaluate(model,
|
evaluate(model,
|
||||||
|
|
|
@ -1,240 +1,241 @@
|
||||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
#
|
#
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
# copy of this software and associated documentation files (the "Software"),
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
# to deal in the Software without restriction, including without limitation
|
# to deal in the Software without restriction, including without limitation
|
||||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
# and/or sell copies of the Software, and to permit persons to whom the
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
# Software is furnished to do so, subject to the following conditions:
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
#
|
#
|
||||||
# The above copyright notice and this permission notice shall be included in
|
# The above copyright notice and this permission notice shall be included in
|
||||||
# all copies or substantial portions of the Software.
|
# all copies or substantial portions of the Software.
|
||||||
#
|
#
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
# DEALINGS IN THE SOFTWARE.
|
# DEALINGS IN THE SOFTWARE.
|
||||||
#
|
#
|
||||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from apex.optimizers import FusedAdam, FusedLAMB
|
from apex.optimizers import FusedAdam, FusedLAMB
|
||||||
from torch.nn.modules.loss import _Loss
|
from torch.nn.modules.loss import _Loss
|
||||||
from torch.nn.parallel import DistributedDataParallel
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.utils.data import DataLoader, DistributedSampler
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from se3_transformer.data_loading import QM9DataModule
|
from se3_transformer.data_loading import QM9DataModule
|
||||||
from se3_transformer.model import SE3TransformerPooled
|
from se3_transformer.model import SE3TransformerPooled
|
||||||
from se3_transformer.model.fiber import Fiber
|
from se3_transformer.model.fiber import Fiber
|
||||||
from se3_transformer.runtime import gpu_affinity
|
from se3_transformer.runtime import gpu_affinity
|
||||||
from se3_transformer.runtime.arguments import PARSER
|
from se3_transformer.runtime.arguments import PARSER
|
||||||
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
|
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
|
||||||
PerformanceCallback
|
PerformanceCallback
|
||||||
from se3_transformer.runtime.inference import evaluate
|
from se3_transformer.runtime.inference import evaluate
|
||||||
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
|
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
|
||||||
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
|
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
|
||||||
using_tensor_cores, increase_l2_fetch_granularity
|
using_tensor_cores, increase_l2_fetch_granularity
|
||||||
|
|
||||||
|
|
||||||
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
|
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
|
||||||
""" Saves model, optimizer and epoch states to path (only once per node) """
|
""" Saves model, optimizer and epoch states to path (only once per node) """
|
||||||
if get_local_rank() == 0:
|
if get_local_rank() == 0:
|
||||||
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
|
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
|
||||||
checkpoint = {
|
checkpoint = {
|
||||||
'state_dict': state_dict,
|
'state_dict': state_dict,
|
||||||
'optimizer_state_dict': optimizer.state_dict(),
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
'epoch': epoch
|
'epoch': epoch
|
||||||
}
|
}
|
||||||
for callback in callbacks:
|
for callback in callbacks:
|
||||||
callback.on_checkpoint_save(checkpoint)
|
callback.on_checkpoint_save(checkpoint)
|
||||||
|
|
||||||
torch.save(checkpoint, str(path))
|
torch.save(checkpoint, str(path))
|
||||||
logging.info(f'Saved checkpoint to {str(path)}')
|
logging.info(f'Saved checkpoint to {str(path)}')
|
||||||
|
|
||||||
|
|
||||||
def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
|
def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
|
||||||
""" Loads model, optimizer and epoch states from path """
|
""" Loads model, optimizer and epoch states from path """
|
||||||
checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
|
checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
|
||||||
if isinstance(model, DistributedDataParallel):
|
if isinstance(model, DistributedDataParallel):
|
||||||
model.module.load_state_dict(checkpoint['state_dict'])
|
model.module.load_state_dict(checkpoint['state_dict'])
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(checkpoint['state_dict'])
|
model.load_state_dict(checkpoint['state_dict'])
|
||||||
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||||
|
|
||||||
for callback in callbacks:
|
for callback in callbacks:
|
||||||
callback.on_checkpoint_load(checkpoint)
|
callback.on_checkpoint_load(checkpoint)
|
||||||
|
|
||||||
logging.info(f'Loaded checkpoint from {str(path)}')
|
logging.info(f'Loaded checkpoint from {str(path)}')
|
||||||
return checkpoint['epoch']
|
return checkpoint['epoch']
|
||||||
|
|
||||||
|
|
||||||
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
|
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
|
||||||
losses = []
|
losses = []
|
||||||
for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
|
for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
|
||||||
desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
|
desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
|
||||||
*inputs, target = to_cuda(batch)
|
*inputs, target = to_cuda(batch)
|
||||||
|
|
||||||
for callback in callbacks:
|
for callback in callbacks:
|
||||||
callback.on_batch_start()
|
callback.on_batch_start()
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(enabled=args.amp):
|
with torch.cuda.amp.autocast(enabled=args.amp):
|
||||||
pred = model(*inputs)
|
pred = model(*inputs)
|
||||||
loss = loss_fn(pred, target) / args.accumulate_grad_batches
|
loss = loss_fn(pred, target) / args.accumulate_grad_batches
|
||||||
|
|
||||||
grad_scaler.scale(loss).backward()
|
grad_scaler.scale(loss).backward()
|
||||||
|
|
||||||
# gradient accumulation
|
# gradient accumulation
|
||||||
if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
|
if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
|
||||||
if args.gradient_clip:
|
if args.gradient_clip:
|
||||||
grad_scaler.unscale_(optimizer)
|
grad_scaler.unscale_(optimizer)
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
|
||||||
|
|
||||||
grad_scaler.step(optimizer)
|
grad_scaler.step(optimizer)
|
||||||
grad_scaler.update()
|
grad_scaler.update()
|
||||||
model.zero_grad(set_to_none=True)
|
model.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
losses.append(loss.item())
|
losses.append(loss.item())
|
||||||
|
|
||||||
return np.mean(losses)
|
return np.mean(losses)
|
||||||
|
|
||||||
|
|
||||||
def train(model: nn.Module,
|
def train(model: nn.Module,
|
||||||
loss_fn: _Loss,
|
loss_fn: _Loss,
|
||||||
train_dataloader: DataLoader,
|
train_dataloader: DataLoader,
|
||||||
val_dataloader: DataLoader,
|
val_dataloader: DataLoader,
|
||||||
callbacks: List[BaseCallback],
|
callbacks: List[BaseCallback],
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
args):
|
args):
|
||||||
device = torch.cuda.current_device()
|
device = torch.cuda.current_device()
|
||||||
model.to(device=device)
|
model.to(device=device)
|
||||||
local_rank = get_local_rank()
|
local_rank = get_local_rank()
|
||||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||||
|
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
|
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
|
||||||
|
model._set_static_graph()
|
||||||
model.train()
|
|
||||||
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
|
model.train()
|
||||||
if args.optimizer == 'adam':
|
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
|
||||||
optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
|
if args.optimizer == 'adam':
|
||||||
weight_decay=args.weight_decay)
|
optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
|
||||||
elif args.optimizer == 'lamb':
|
weight_decay=args.weight_decay)
|
||||||
optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
|
elif args.optimizer == 'lamb':
|
||||||
weight_decay=args.weight_decay)
|
optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
|
||||||
else:
|
weight_decay=args.weight_decay)
|
||||||
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
|
else:
|
||||||
weight_decay=args.weight_decay)
|
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
|
||||||
|
weight_decay=args.weight_decay)
|
||||||
epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0
|
|
||||||
|
epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0
|
||||||
for callback in callbacks:
|
|
||||||
callback.on_fit_start(optimizer, args)
|
for callback in callbacks:
|
||||||
|
callback.on_fit_start(optimizer, args)
|
||||||
for epoch_idx in range(epoch_start, args.epochs):
|
|
||||||
if isinstance(train_dataloader.sampler, DistributedSampler):
|
for epoch_idx in range(epoch_start, args.epochs):
|
||||||
train_dataloader.sampler.set_epoch(epoch_idx)
|
if isinstance(train_dataloader.sampler, DistributedSampler):
|
||||||
|
train_dataloader.sampler.set_epoch(epoch_idx)
|
||||||
loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
|
|
||||||
if dist.is_initialized():
|
loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks,
|
||||||
loss = torch.tensor(loss, dtype=torch.float, device=device)
|
args)
|
||||||
torch.distributed.all_reduce(loss)
|
if dist.is_initialized():
|
||||||
loss = (loss / world_size).item()
|
loss = torch.tensor(loss, dtype=torch.float, device=device)
|
||||||
|
torch.distributed.all_reduce(loss)
|
||||||
logging.info(f'Train loss: {loss}')
|
loss = (loss / world_size).item()
|
||||||
logger.log_metrics({'train loss': loss}, epoch_idx)
|
|
||||||
|
logging.info(f'Train loss: {loss}')
|
||||||
for callback in callbacks:
|
logger.log_metrics({'train loss': loss}, epoch_idx)
|
||||||
callback.on_epoch_end()
|
|
||||||
|
for callback in callbacks:
|
||||||
if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
|
callback.on_epoch_end()
|
||||||
and (epoch_idx + 1) % args.ckpt_interval == 0:
|
|
||||||
save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)
|
if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
|
||||||
|
and (epoch_idx + 1) % args.ckpt_interval == 0:
|
||||||
if not args.benchmark and ((args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0) or epoch_idx + 1 == args.epochs):
|
save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)
|
||||||
evaluate(model, val_dataloader, callbacks, args)
|
|
||||||
model.train()
|
if not args.benchmark and (
|
||||||
|
(args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0) or epoch_idx + 1 == args.epochs):
|
||||||
for callback in callbacks:
|
evaluate(model, val_dataloader, callbacks, args)
|
||||||
callback.on_validation_end(epoch_idx)
|
model.train()
|
||||||
|
|
||||||
if args.save_ckpt_path is not None and not args.benchmark:
|
for callback in callbacks:
|
||||||
save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)
|
callback.on_validation_end(epoch_idx)
|
||||||
|
|
||||||
for callback in callbacks:
|
if args.save_ckpt_path is not None and not args.benchmark:
|
||||||
callback.on_fit_end()
|
save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)
|
||||||
|
|
||||||
|
for callback in callbacks:
|
||||||
def print_parameters_count(model):
|
callback.on_fit_end()
|
||||||
num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
||||||
logging.info(f'Number of trainable parameters: {num_params_trainable}')
|
|
||||||
|
def print_parameters_count(model):
|
||||||
|
num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
if __name__ == '__main__':
|
logging.info(f'Number of trainable parameters: {num_params_trainable}')
|
||||||
is_distributed = init_distributed()
|
|
||||||
local_rank = get_local_rank()
|
|
||||||
args = PARSER.parse_args()
|
if __name__ == '__main__':
|
||||||
|
is_distributed = init_distributed()
|
||||||
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
|
local_rank = get_local_rank()
|
||||||
|
args = PARSER.parse_args()
|
||||||
logging.info('====== SE(3)-Transformer ======')
|
|
||||||
logging.info('| Training procedure |')
|
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
|
||||||
logging.info('===============================')
|
|
||||||
|
logging.info('====== SE(3)-Transformer ======')
|
||||||
if args.seed is not None:
|
logging.info('| Training procedure |')
|
||||||
logging.info(f'Using seed {args.seed}')
|
logging.info('===============================')
|
||||||
seed_everything(args.seed)
|
|
||||||
|
if args.seed is not None:
|
||||||
logger = LoggerCollection([
|
logging.info(f'Using seed {args.seed}')
|
||||||
DLLogger(save_dir=args.log_dir, filename=args.dllogger_name),
|
seed_everything(args.seed)
|
||||||
WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
|
|
||||||
])
|
loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)]
|
||||||
|
if args.wandb:
|
||||||
datamodule = QM9DataModule(**vars(args))
|
loggers.append(WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer'))
|
||||||
model = SE3TransformerPooled(
|
logger = LoggerCollection(loggers)
|
||||||
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
|
|
||||||
fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
|
datamodule = QM9DataModule(**vars(args))
|
||||||
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
|
model = SE3TransformerPooled(
|
||||||
output_dim=1,
|
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
|
||||||
tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively
|
fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
|
||||||
**vars(args)
|
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
|
||||||
)
|
output_dim=1,
|
||||||
loss_fn = nn.L1Loss()
|
tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively
|
||||||
|
**vars(args)
|
||||||
if args.benchmark:
|
)
|
||||||
logging.info('Running benchmark mode')
|
loss_fn = nn.L1Loss()
|
||||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
|
||||||
callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
|
if args.benchmark:
|
||||||
else:
|
logging.info('Running benchmark mode')
|
||||||
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
|
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||||
QM9LRSchedulerCallback(logger, epochs=args.epochs)]
|
callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
|
||||||
|
else:
|
||||||
if is_distributed:
|
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
|
||||||
gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())
|
QM9LRSchedulerCallback(logger, epochs=args.epochs)]
|
||||||
|
|
||||||
print_parameters_count(model)
|
if is_distributed:
|
||||||
logger.log_hyperparams(vars(args))
|
gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())
|
||||||
increase_l2_fetch_granularity()
|
|
||||||
train(model,
|
print_parameters_count(model)
|
||||||
loss_fn,
|
logger.log_hyperparams(vars(args))
|
||||||
datamodule.train_dataloader(),
|
increase_l2_fetch_granularity()
|
||||||
datamodule.val_dataloader(),
|
train(model,
|
||||||
callbacks,
|
loss_fn,
|
||||||
logger,
|
datamodule.train_dataloader(),
|
||||||
args)
|
datamodule.val_dataloader(),
|
||||||
|
callbacks,
|
||||||
logging.info('Training finished successfully')
|
logger,
|
||||||
|
args)
|
||||||
|
|
||||||
|
logging.info('Training finished successfully')
|
||||||
|
|
|
@ -2,9 +2,9 @@ from setuptools import setup, find_packages
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='se3-transformer',
|
name='se3-transformer',
|
||||||
packages=find_packages(),
|
packages=find_packages(exclude=['tests']),
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
version='1.0.0',
|
version='1.1.0',
|
||||||
description='PyTorch + DGL implementation of SE(3)-Transformers',
|
description='PyTorch + DGL implementation of SE(3)-Transformers',
|
||||||
author='Alexandre Milesi',
|
author='Alexandre Milesi',
|
||||||
author_email='alexandrem@nvidia.com',
|
author_email='alexandrem@nvidia.com',
|
||||||
|
|
|
@ -1,102 +1,106 @@
|
||||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
#
|
#
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
# copy of this software and associated documentation files (the "Software"),
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
# to deal in the Software without restriction, including without limitation
|
# to deal in the Software without restriction, including without limitation
|
||||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
# and/or sell copies of the Software, and to permit persons to whom the
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
# Software is furnished to do so, subject to the following conditions:
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
#
|
#
|
||||||
# The above copyright notice and this permission notice shall be included in
|
# The above copyright notice and this permission notice shall be included in
|
||||||
# all copies or substantial portions of the Software.
|
# all copies or substantial portions of the Software.
|
||||||
#
|
#
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
# DEALINGS IN THE SOFTWARE.
|
# DEALINGS IN THE SOFTWARE.
|
||||||
#
|
#
|
||||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from se3_transformer.model import SE3Transformer
|
from se3_transformer.model import SE3Transformer
|
||||||
from se3_transformer.model.fiber import Fiber
|
from se3_transformer.model.fiber import Fiber
|
||||||
from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot
|
|
||||||
|
if __package__ is None or __package__ == '':
|
||||||
# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) )
|
from utils import get_random_graph, assign_relative_pos, get_max_diff, rot
|
||||||
TOL = 1e-3
|
else:
|
||||||
CHANNELS, NODES = 32, 512
|
from .utils import get_random_graph, assign_relative_pos, get_max_diff, rot
|
||||||
|
|
||||||
|
# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) )
|
||||||
def _get_outputs(model, R):
|
TOL = 1e-3
|
||||||
feats0 = torch.randn(NODES, CHANNELS, 1)
|
CHANNELS, NODES = 32, 512
|
||||||
feats1 = torch.randn(NODES, CHANNELS, 3)
|
|
||||||
|
|
||||||
coords = torch.randn(NODES, 3)
|
def _get_outputs(model, R):
|
||||||
graph = get_random_graph(NODES)
|
feats0 = torch.randn(NODES, CHANNELS, 1)
|
||||||
if torch.cuda.is_available():
|
feats1 = torch.randn(NODES, CHANNELS, 3)
|
||||||
feats0 = feats0.cuda()
|
|
||||||
feats1 = feats1.cuda()
|
coords = torch.randn(NODES, 3)
|
||||||
R = R.cuda()
|
graph = get_random_graph(NODES)
|
||||||
coords = coords.cuda()
|
if torch.cuda.is_available():
|
||||||
graph = graph.to('cuda')
|
feats0 = feats0.cuda()
|
||||||
model.cuda()
|
feats1 = feats1.cuda()
|
||||||
|
R = R.cuda()
|
||||||
graph1 = assign_relative_pos(graph, coords)
|
coords = coords.cuda()
|
||||||
out1 = model(graph1, {'0': feats0, '1': feats1}, {})
|
graph = graph.to('cuda')
|
||||||
graph2 = assign_relative_pos(graph, coords @ R)
|
model.cuda()
|
||||||
out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {})
|
|
||||||
|
graph1 = assign_relative_pos(graph, coords)
|
||||||
return out1, out2
|
out1 = model(graph1, {'0': feats0, '1': feats1}, {})
|
||||||
|
graph2 = assign_relative_pos(graph, coords @ R)
|
||||||
|
out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {})
|
||||||
def _get_model(**kwargs):
|
|
||||||
return SE3Transformer(
|
return out1, out2
|
||||||
num_layers=4,
|
|
||||||
fiber_in=Fiber.create(2, CHANNELS),
|
|
||||||
fiber_hidden=Fiber.create(3, CHANNELS),
|
def _get_model(**kwargs):
|
||||||
fiber_out=Fiber.create(2, CHANNELS),
|
return SE3Transformer(
|
||||||
fiber_edge=Fiber({}),
|
num_layers=4,
|
||||||
num_heads=8,
|
fiber_in=Fiber.create(2, CHANNELS),
|
||||||
channels_div=2,
|
fiber_hidden=Fiber.create(3, CHANNELS),
|
||||||
**kwargs
|
fiber_out=Fiber.create(2, CHANNELS),
|
||||||
)
|
fiber_edge=Fiber({}),
|
||||||
|
num_heads=8,
|
||||||
|
channels_div=2,
|
||||||
def test_equivariance():
|
**kwargs
|
||||||
model = _get_model()
|
)
|
||||||
R = rot(*torch.rand(3))
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
R = R.cuda()
|
def test_equivariance():
|
||||||
out1, out2 = _get_outputs(model, R)
|
model = _get_model()
|
||||||
|
R = rot(*torch.rand(3))
|
||||||
assert torch.allclose(out2['0'], out1['0'], atol=TOL), \
|
if torch.cuda.is_available():
|
||||||
f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}'
|
R = R.cuda()
|
||||||
assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \
|
out1, out2 = _get_outputs(model, R)
|
||||||
f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}'
|
|
||||||
|
assert torch.allclose(out2['0'], out1['0'], atol=TOL), \
|
||||||
|
f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}'
|
||||||
def test_equivariance_pooled():
|
assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \
|
||||||
model = _get_model(pooling='avg', return_type=1)
|
f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}'
|
||||||
R = rot(*torch.rand(3))
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
R = R.cuda()
|
def test_equivariance_pooled():
|
||||||
out1, out2 = _get_outputs(model, R)
|
model = _get_model(pooling='avg', return_type=1)
|
||||||
|
R = rot(*torch.rand(3))
|
||||||
assert torch.allclose(out2, (out1 @ R), atol=TOL), \
|
if torch.cuda.is_available():
|
||||||
f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}'
|
R = R.cuda()
|
||||||
|
out1, out2 = _get_outputs(model, R)
|
||||||
|
|
||||||
def test_invariance_pooled():
|
assert torch.allclose(out2, (out1 @ R), atol=TOL), \
|
||||||
model = _get_model(pooling='avg', return_type=0)
|
f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}'
|
||||||
R = rot(*torch.rand(3))
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
R = R.cuda()
|
def test_invariance_pooled():
|
||||||
out1, out2 = _get_outputs(model, R)
|
model = _get_model(pooling='avg', return_type=0)
|
||||||
|
R = rot(*torch.rand(3))
|
||||||
assert torch.allclose(out2, out1, atol=TOL), \
|
if torch.cuda.is_available():
|
||||||
f'type-0 features should be invariant {get_max_diff(out1, out2)}'
|
R = R.cuda()
|
||||||
|
out1, out2 = _get_outputs(model, R)
|
||||||
|
|
||||||
|
assert torch.allclose(out2, out1, atol=TOL), \
|
||||||
|
f'type-0 features should be invariant {get_max_diff(out1, out2)}'
|
||||||
|
|
|
@ -1,60 +1,60 @@
|
||||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
#
|
#
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
# copy of this software and associated documentation files (the "Software"),
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
# to deal in the Software without restriction, including without limitation
|
# to deal in the Software without restriction, including without limitation
|
||||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
# and/or sell copies of the Software, and to permit persons to whom the
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
# Software is furnished to do so, subject to the following conditions:
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
#
|
#
|
||||||
# The above copyright notice and this permission notice shall be included in
|
# The above copyright notice and this permission notice shall be included in
|
||||||
# all copies or substantial portions of the Software.
|
# all copies or substantial portions of the Software.
|
||||||
#
|
#
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
# DEALINGS IN THE SOFTWARE.
|
# DEALINGS IN THE SOFTWARE.
|
||||||
#
|
#
|
||||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
import dgl
|
import dgl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def get_random_graph(N, num_edges_factor=18):
|
def get_random_graph(N, num_edges_factor=18):
|
||||||
graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor))
|
graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor))
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
def assign_relative_pos(graph, coords):
|
def assign_relative_pos(graph, coords):
|
||||||
src, dst = graph.edges()
|
src, dst = graph.edges()
|
||||||
graph.edata['rel_pos'] = coords[src] - coords[dst]
|
graph.edata['rel_pos'] = coords[src] - coords[dst]
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
def get_max_diff(a, b):
|
def get_max_diff(a, b):
|
||||||
return (a - b).abs().max().item()
|
return (a - b).abs().max().item()
|
||||||
|
|
||||||
|
|
||||||
def rot_z(gamma):
|
def rot_z(gamma):
|
||||||
return torch.tensor([
|
return torch.tensor([
|
||||||
[torch.cos(gamma), -torch.sin(gamma), 0],
|
[torch.cos(gamma), -torch.sin(gamma), 0],
|
||||||
[torch.sin(gamma), torch.cos(gamma), 0],
|
[torch.sin(gamma), torch.cos(gamma), 0],
|
||||||
[0, 0, 1]
|
[0, 0, 1]
|
||||||
], dtype=gamma.dtype)
|
], dtype=gamma.dtype)
|
||||||
|
|
||||||
|
|
||||||
def rot_y(beta):
|
def rot_y(beta):
|
||||||
return torch.tensor([
|
return torch.tensor([
|
||||||
[torch.cos(beta), 0, torch.sin(beta)],
|
[torch.cos(beta), 0, torch.sin(beta)],
|
||||||
[0, 1, 0],
|
[0, 1, 0],
|
||||||
[-torch.sin(beta), 0, torch.cos(beta)]
|
[-torch.sin(beta), 0, torch.cos(beta)]
|
||||||
], dtype=beta.dtype)
|
], dtype=beta.dtype)
|
||||||
|
|
||||||
|
|
||||||
def rot(alpha, beta, gamma):
|
def rot(alpha, beta, gamma):
|
||||||
return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
|
return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
|
||||||
|
|
|
@ -83,7 +83,7 @@ These examples, along with our NVIDIA deep learning software stack, are provided
|
||||||
## Graph Neural Networks
|
## Graph Neural Networks
|
||||||
| Models | Framework | A100 | AMP | Multi-GPU | Multi-Node | TRT | ONNX | Triton | DLC | NB |
|
| Models | Framework | A100 | AMP | Multi-GPU | Multi-Node | TRT | ONNX | Triton | DLC | NB |
|
||||||
| ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |------------- |------------- |------------- |------------- |------------- |
|
| ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |------------- |------------- |------------- |------------- |------------- |
|
||||||
| [SE(3)-Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/DrugDiscovery/SE3Transformer) | PyTorch | Yes | Yes | Yes | - | - | - | - | - | - |
|
| [SE(3)-Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/DGLPyTorch/DrugDiscovery/SE3Transformer) | PyTorch | Yes | Yes | Yes | - | - | - | - | - | - |
|
||||||
|
|
||||||
|
|
||||||
## NVIDIA support
|
## NVIDIA support
|
||||||
|
|
Loading…
Reference in a new issue