[SE3Transformer/DGLPyT] Fix multiple shapes mismatch with specific hyperparams
commit 9b86d8c0b8 (parent 028534a5b9)
@@ -328,7 +328,7 @@ The complete list of the available parameters for the `training.py` script conta
 - `--gradient_clip`: Clipping of the gradient norms (default: `None`)
 - `--accumulate_grad_batches`: Gradient accumulation (default: `1`)
 - `--ckpt_interval`: Save a checkpoint every N epochs (default: `-1`)
-- `--eval_interval`: Do an evaluation round every N epochs (default: `1`)
+- `--eval_interval`: Do an evaluation round every N epochs (default: `20`)
 - `--silent`: Minimize stdout output (default: `false`)
 
 **Paths**
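For readers new to these options, here is a minimal PyTorch sketch of what gradient accumulation and gradient-norm clipping do conceptually. This is an illustration only, not the actual `training.py` implementation:

```python
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
accumulate_grad_batches, gradient_clip = 2, 1.0

for step, batch in enumerate(torch.randn(8, 4, 8)):
    # Scale each micro-batch loss so the accumulated gradient matches one big batch
    loss = model(batch).pow(2).mean() / accumulate_grad_batches
    loss.backward()  # gradients accumulate across backward() calls
    if (step + 1) % accumulate_grad_batches == 0:
        # Clip the global gradient norm before the optimizer step
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()
        optimizer.zero_grad()
```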
@@ -485,6 +485,7 @@ Our results were obtained by running the `scripts/train.sh` training script in t
 | 8 | 240 | 0.03380 | 0.03495 | 29min | 20min | 1.45x |
 
 
+
 #### Training performance results
 
 ##### Training performance: NVIDIA DGX A100 (8x A100 80GB)
@@ -495,8 +496,8 @@ Our results were obtained by running the `scripts/benchmark_train.sh` and `scrip
 |:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
 | 1 | 240 | 2.21 | 2.92 | 1.32x | | |
 | 1 | 120 | 1.81 | 2.04 | 1.13x | | |
-| 8 | 240 | 17.15 | 22.95 | 1.34x | 7.76 | 7.86 |
-| 8 | 120 | 13.89 | 15.62 | 1.12x | 7.67 | 7.66 |
+| 8 | 240 | 15.88 | 21.02 | 1.32x | 7.18 | 7.20 |
+| 8 | 120 | 12.68 | 13.99 | 1.10x | 7.00 | 6.86 |
 
 
 To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
@@ -510,8 +511,8 @@ Our results were obtained by running the `scripts/benchmark_train.sh` and `scrip
 |:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
 | 1 | 240 | 1.25 | 1.88 | 1.50x | | |
 | 1 | 120 | 1.03 | 1.41 | 1.37x | | |
-| 8 | 240 | 9.33 | 14.02 | 1.50x | 7.46 | 7.46 |
-| 8 | 120 | 7.39 | 9.41 | 1.27x | 7.17 | 6.67 |
+| 8 | 240 | 8.68 | 12.75 | 1.47x | 6.94 | 6.78 |
+| 8 | 120 | 6.64 | 8.58 | 1.29x | 6.44 | 6.08 |
 
 
 To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
@@ -572,6 +573,15 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
 
 ### Changelog
 
+October 2021:
+- Updated README performance tables
+- Fixed shape mismatch when using partially fused TFNs per output degree
+- Fixed shape mismatch when using partially fused TFNs per input degree with edge degrees > 0
+
+September 2021:
+- Moved to new location (from `PyTorch/DrugDiscovery` to `DGLPyTorch/DrugDiscovery`)
+- Fixed multi-GPU training script
+
 August 2021
 - Initial release
 
se3_transformer/model/__init__.py
@@ -1,2 +1,2 @@
 from .transformer import SE3Transformer, SE3TransformerPooled
 from .fiber import Fiber
se3_transformer/model/basis.py
@@ -54,9 +54,8 @@ def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]:
 
 def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]:
     all_degrees = list(range(2 * max_degree + 1))
-    with nvtx_range('spherical harmonics'):
-        sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
-        return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)
+    sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
+    return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)
 
 
 @torch.jit.script
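As a side note on the hunk above: the split sizes in `get_spherical_harmonics` follow from `degree_to_dim(d) == 2 * d + 1` (the dimension of a degree-d feature, as documented in the `Fiber` class below). A minimal sketch, assuming only that relation:

```python
# Output layout of get_spherical_harmonics for max_degree = 2.
# Assumes degree_to_dim(d) == 2 * d + 1 (see the Fiber docstring below).
max_degree = 2
all_degrees = list(range(2 * max_degree + 1))   # [0, 1, 2, 3, 4]
split_sizes = [2 * d + 1 for d in all_degrees]  # [1, 3, 5, 7, 9]

# o3.spherical_harmonics returns one (num_edges, sum(split_sizes)) tensor;
# torch.split along dim=1 then yields one (num_edges, 2d+1) tensor per degree d.
print(split_sizes, sum(split_sizes))            # [1, 3, 5, 7, 9] 25
```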
se3_transformer/model/fiber.py
@@ -1,144 +1,144 @@
 # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Permission is hereby granted, free of charge, to any person obtaining a
 # copy of this software and associated documentation files (the "Software"),
 # to deal in the Software without restriction, including without limitation
 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
 # and/or sell copies of the Software, and to permit persons to whom the
 # Software is furnished to do so, subject to the following conditions:
 #
 # The above copyright notice and this permission notice shall be included in
 # all copies or substantial portions of the Software.
 #
 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 # DEALINGS IN THE SOFTWARE.
 #
 # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
 # SPDX-License-Identifier: MIT
 
 
 from collections import namedtuple
 from itertools import product
 from typing import Dict
 
 import torch
 from torch import Tensor
 
 from se3_transformer.runtime.utils import degree_to_dim
 
 FiberEl = namedtuple('FiberEl', ['degree', 'channels'])
 
 
 class Fiber(dict):
     """
     Describes the structure of some set of features.
     Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1.
     Type-0 features: invariant scalars
     Type-1 features: equivariant 3D vectors
     Type-2 features: equivariant symmetric traceless matrices
     ...
 
     As inputs to a SE3 layer, there can be many features of the same types, and many features of different types.
     The 'multiplicity' or 'number of channels' is the number of features of a given type.
     This class puts together all the degrees and their multiplicities in order to describe
     the inputs, outputs or hidden features of SE3 layers.
     """
 
     def __init__(self, structure):
         if isinstance(structure, dict):
             structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])]
         elif not isinstance(structure[0], FiberEl):
             structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1])))
         self.structure = structure
         super().__init__({d: m for d, m in self.structure})
 
     @property
     def degrees(self):
         return sorted([t.degree for t in self.structure])
 
     @property
     def channels(self):
         return [self[d] for d in self.degrees]
 
     @property
     def num_features(self):
         """ Size of the resulting tensor if all features were concatenated together """
         return sum(t.channels * degree_to_dim(t.degree) for t in self.structure)
 
     @staticmethod
     def create(num_degrees: int, num_channels: int):
         """ Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """
         return Fiber([(degree, num_channels) for degree in range(num_degrees)])
 
     @staticmethod
     def from_features(feats: Dict[str, Tensor]):
         """ Infer the Fiber structure from a feature dict """
         structure = {}
         for k, v in feats.items():
             degree = int(k)
             assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
             assert v.shape[-1] == degree_to_dim(degree)
             structure[degree] = v.shape[-2]
         return Fiber(structure)
 
     def __getitem__(self, degree: int):
         """ fiber[degree] returns the multiplicity for this degree """
         return dict(self.structure).get(degree, 0)
 
     def __iter__(self):
         """ Iterate over namedtuples (degree, channels) """
         return iter(self.structure)
 
     def __mul__(self, other):
         """
         If other is an int, multiplies all the multiplicities by other.
         If other is a fiber, returns the cartesian product.
         """
         if isinstance(other, Fiber):
             return product(self.structure, other.structure)
         elif isinstance(other, int):
             return Fiber({t.degree: t.channels * other for t in self.structure})
 
     def __add__(self, other):
         """
         If other is an int, adds other to all the multiplicities.
         If other is a fiber, adds the multiplicities of the fibers together.
         """
         if isinstance(other, Fiber):
             return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure})
         elif isinstance(other, int):
             return Fiber({t.degree: t.channels + other for t in self.structure})
 
     def __repr__(self):
         return str(self.structure)
 
     @staticmethod
     def combine_max(f1, f2):
         """ Combine two fibers by taking the maximum multiplicity for each degree in both fibers """
         new_dict = dict(f1.structure)
         for k, m in f2.structure:
             new_dict[k] = max(new_dict.get(k, 0), m)
 
         return Fiber(list(new_dict.items()))
 
     @staticmethod
     def combine_selectively(f1, f2):
         """ Combine two fibers by taking the sum of multiplicities for each degree in the first fiber """
         # only use orders which occur in fiber f1
         new_dict = dict(f1.structure)
         for k in f1.degrees:
             if k in f2.degrees:
                 new_dict[k] += f2[k]
         return Fiber(list(new_dict.items()))
 
     def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int):
         # dict(N, num_channels, 2d+1) -> (N, num_heads, -1)
         fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in
                   self.degrees]
         fibers = torch.cat(fibers, -1)
         return fibers
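For readers unfamiliar with the `Fiber` abstraction above, a small usage sketch (assumes the repository is on `PYTHONPATH`; the outputs follow from the methods shown in this file):

```python
from se3_transformer.model.fiber import Fiber

# 32 scalar (degree-0) channels and 32 vector (degree-1) channels
fiber = Fiber({0: 32, 1: 32})
print(fiber.degrees)       # [0, 1]
print(fiber.channels)      # [32, 32]
print(fiber.num_features)  # 32 * 1 + 32 * 3 = 128

# __mul__ with an int scales every multiplicity
print(fiber * 2)           # [FiberEl(degree=0, channels=64), FiberEl(degree=1, channels=64)]

# __add__ with a Fiber sums multiplicities, keeping only the left fiber's degrees
print(fiber + Fiber({1: 8, 2: 4}))  # degree 0 -> 32, degree 1 -> 40 (degree 2 dropped)
```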
se3_transformer/model/layers/__init__.py
@@ -1,5 +1,5 @@
 from .linear import LinearSE3
 from .norm import NormSE3
 from .pooling import GPooling
 from .convolution import ConvSE3
 from .attention import AttentionBlockSE3
se3_transformer/model/layers/attention.py
@@ -1,182 +1,180 @@
 # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Permission is hereby granted, free of charge, to any person obtaining a
 # copy of this software and associated documentation files (the "Software"),
 # to deal in the Software without restriction, including without limitation
 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
 # and/or sell copies of the Software, and to permit persons to whom the
 # Software is furnished to do so, subject to the following conditions:
 #
 # The above copyright notice and this permission notice shall be included in
 # all copies or substantial portions of the Software.
 #
 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 # DEALINGS IN THE SOFTWARE.
 #
 # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
 # SPDX-License-Identifier: MIT
 
 import dgl
 import numpy as np
 import torch
 import torch.nn as nn
 from dgl import DGLGraph
 from dgl.ops import edge_softmax
 from torch import Tensor
 from typing import Dict, Optional, Union
 
 from se3_transformer.model.fiber import Fiber
 from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
 from se3_transformer.model.layers.linear import LinearSE3
 from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
 from torch.cuda.nvtx import range as nvtx_range
 
 
 class AttentionSE3(nn.Module):
     """ Multi-headed sparse graph self-attention (SE(3)-equivariant) """
 
     def __init__(
             self,
             num_heads: int,
             key_fiber: Fiber,
             value_fiber: Fiber
     ):
         """
         :param num_heads: Number of attention heads
         :param key_fiber: Fiber for the keys (and also for the queries)
         :param value_fiber: Fiber for the values
         """
         super().__init__()
         self.num_heads = num_heads
         self.key_fiber = key_fiber
         self.value_fiber = value_fiber
 
     def forward(
             self,
             value: Union[Tensor, Dict[str, Tensor]],  # edge features (may be fused)
             key: Union[Tensor, Dict[str, Tensor]],  # edge features (may be fused)
             query: Dict[str, Tensor],  # node features
             graph: DGLGraph
     ):
         with nvtx_range('AttentionSE3'):
             with nvtx_range('reshape keys and queries'):
                 if isinstance(key, Tensor):
                     # case where features of all types are fused
                     key = key.reshape(key.shape[0], self.num_heads, -1)
                     # need to reshape queries that way to keep the same layout as keys
                     out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
                     query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
                 else:
                     # features are not fused, need to fuse and reshape them
                     key = self.key_fiber.to_attention_heads(key, self.num_heads)
                     query = self.key_fiber.to_attention_heads(query, self.num_heads)
 
             with nvtx_range('attention dot product + softmax'):
                 # Compute attention weights (softmax of inner product between key and query)
                 edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
                 edge_weights = edge_weights / np.sqrt(self.key_fiber.num_features)
                 edge_weights = edge_softmax(graph, edge_weights)
                 edge_weights = edge_weights[..., None, None]
 
             with nvtx_range('weighted sum'):
                 if isinstance(value, Tensor):
                     # features of all types are fused
                     v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
                     weights = edge_weights * v
                     feat_out = dgl.ops.copy_e_sum(graph, weights)
                     feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1])  # merge heads
                     out = unfuse_features(feat_out, self.value_fiber.degrees)
                 else:
                     out = {}
                     for degree, channels in self.value_fiber:
                         v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
                                                     degree_to_dim(degree))
                         weights = edge_weights * v
                         res = dgl.ops.copy_e_sum(graph, weights)
                         out[str(degree)] = res.view(-1, channels, degree_to_dim(degree))  # merge heads
 
             return out
 
 
 class AttentionBlockSE3(nn.Module):
     """ Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """
 
     def __init__(
             self,
             fiber_in: Fiber,
             fiber_out: Fiber,
             fiber_edge: Optional[Fiber] = None,
             num_heads: int = 4,
             channels_div: int = 2,
             use_layer_norm: bool = False,
             max_degree: int = 4,
             fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
             **kwargs
     ):
         """
         :param fiber_in: Fiber describing the input features
         :param fiber_out: Fiber describing the output features
         :param fiber_edge: Fiber describing the edge features (node distances excluded)
         :param num_heads: Number of attention heads
         :param channels_div: Divide the channels by this integer for computing values
         :param use_layer_norm: Apply layer normalization between MLP layers
         :param max_degree: Maximum degree used in the bases computation
         :param fuse_level: Maximum fuse level to use in TFN convolutions
         """
         super().__init__()
         if fiber_edge is None:
             fiber_edge = Fiber({})
         self.fiber_in = fiber_in
         # value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
         value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out])
         # key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber
         # (queries are merely projected, hence degrees have to match input)
         key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])
 
         self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
                                     use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
                                     allow_fused_output=True)
         self.to_query = LinearSE3(fiber_in, key_query_fiber)
         self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
         self.project = LinearSE3(value_fiber + fiber_in, fiber_out)
 
     def forward(
             self,
             node_features: Dict[str, Tensor],
             edge_features: Dict[str, Tensor],
             graph: DGLGraph,
             basis: Dict[str, Tensor]
     ):
         with nvtx_range('AttentionBlockSE3'):
             with nvtx_range('keys / values'):
                 fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
                 key, value = self._get_key_value_from_fused(fused_key_value)
 
             with nvtx_range('queries'):
                 query = self.to_query(node_features)
 
             z = self.attention(value, key, query, graph)
             z_concat = aggregate_residual(node_features, z, 'cat')
             return self.project(z_concat)
 
     def _get_key_value_from_fused(self, fused_key_value):
         # Extract keys and values from fused features
         if isinstance(fused_key_value, Tensor):
             # Previous layer was a fully fused convolution
             value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
         else:
             key, value = {}, {}
             for degree, feat in fused_key_value.items():
                 if int(degree) in self.fiber_in.degrees:
                     value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
                 else:
                     value[degree] = feat
 
         return key, value
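A small standalone sketch of the fused key/value split performed in `_get_key_value_from_fused` above (plain PyTorch, no graph required; the shapes are illustrative):

```python
import torch

# A fused key/value tensor as produced by the to_key_value convolution:
# the channel axis (dim=-2) holds the value channels followed by the key channels.
num_edges, value_channels, dim = 10, 16, 3
fused = torch.randn(num_edges, 2 * value_channels, dim)

# chunks=2 along dim=-2: first half -> values, second half -> keys
value, key = torch.chunk(fused, chunks=2, dim=-2)
print(value.shape, key.shape)  # torch.Size([10, 16, 3]) torch.Size([10, 16, 3])
```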
se3_transformer/model/layers/convolution.py
@@ -1,334 +1,345 @@
 # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Permission is hereby granted, free of charge, to any person obtaining a
 # copy of this software and associated documentation files (the "Software"),
 # to deal in the Software without restriction, including without limitation
 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
 # and/or sell copies of the Software, and to permit persons to whom the
 # Software is furnished to do so, subject to the following conditions:
 #
 # The above copyright notice and this permission notice shall be included in
 # all copies or substantial portions of the Software.
 #
 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 # DEALINGS IN THE SOFTWARE.
 #
 # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
 # SPDX-License-Identifier: MIT
 
 from enum import Enum
 from itertools import product
 from typing import Dict
 
 import dgl
 import numpy as np
 import torch
 import torch.nn as nn
 from dgl import DGLGraph
 from torch import Tensor
 from torch.cuda.nvtx import range as nvtx_range
 
 from se3_transformer.model.fiber import Fiber
 from se3_transformer.runtime.utils import degree_to_dim, unfuse_features
 
 
 class ConvSE3FuseLevel(Enum):
     """
     Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met.
     If a desired level L is picked and cannot be applied to a module, other fused ops < L are considered.
     A higher level means faster training, but also more memory usage.
     If you are tight on memory and want to feed large inputs to the network, choose a low value.
     If you want to train fast, choose a high value.
     Recommended value is FULL with AMP.
 
     Fully fused TFN convolutions requirements:
     - all input channels are the same
     - all output channels are the same
     - input degrees span the range [0, ..., max_degree]
     - output degrees span the range [0, ..., max_degree]
 
     Partially fused TFN convolutions requirements:
     * For fusing by output degree:
     - all input channels are the same
     - input degrees span the range [0, ..., max_degree]
     * For fusing by input degree:
     - all output channels are the same
     - output degrees span the range [0, ..., max_degree]
 
     Original TFN pairwise convolutions: no requirements
     """
 
     FULL = 2
     PARTIAL = 1
     NONE = 0
 
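To make the fuse-level requirements above concrete, here is a standalone sketch mirroring the branch selection that `ConvSE3.__init__` performs below. It is simplified: it ignores the extra input channels contributed by edge-feature concatenation for degrees > 0.

```python
def used_fuse_level(fiber_in: dict, fiber_out: dict, max_degree: int) -> str:
    # fiber_in / fiber_out map degree -> number of channels (multiplicity)
    degrees_up_to_max = list(range(max_degree + 1))
    unique_in = len(set(fiber_in.values())) == 1
    unique_out = len(set(fiber_out.values())) == 1
    in_spans = sorted(fiber_in) == degrees_up_to_max
    out_spans = sorted(fiber_out) == degrees_up_to_max

    if unique_in and in_spans and unique_out and out_spans:
        return 'FULL'                              # single fused convolution
    if unique_in and in_spans:
        return 'PARTIAL, fused per output degree'  # first path fixed by this commit
    if unique_out and out_spans:
        return 'PARTIAL, fused per input degree'   # second path fixed by this commit
    return 'NONE'                                  # pairwise TFN convolutions

# Equal input channels spanning [0..2], but distinct output channels per degree:
print(used_fuse_level({0: 32, 1: 32, 2: 32}, {0: 16, 1: 8, 2: 4}, max_degree=2))
# -> 'PARTIAL, fused per output degree'
```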
 class RadialProfile(nn.Module):
     """
     Radial profile function.
     Outputs weights used to weigh basis matrices in order to get convolution kernels.
     In TFN notation: $R^{l,k}$
     In SE(3)-Transformer notation: $\phi^{l,k}$
 
     Note:
         In the original papers, this function only depends on relative node distances ||x||.
         Here, we allow this function to also take as input additional invariant edge features.
         This does not break equivariance and adds expressive power to the model.
 
     Diagram:
         invariant edge features (node distances included) ───> MLP layer (shared across edges) ───> radial weights
     """
 
     def __init__(
             self,
             num_freq: int,
             channels_in: int,
             channels_out: int,
             edge_dim: int = 1,
             mid_dim: int = 32,
             use_layer_norm: bool = False
     ):
         """
         :param num_freq: Number of frequencies
         :param channels_in: Number of input channels
         :param channels_out: Number of output channels
         :param edge_dim: Number of invariant edge features (input to the radial function)
         :param mid_dim: Size of the hidden MLP layers
         :param use_layer_norm: Apply layer normalization between MLP layers
         """
         super().__init__()
         modules = [
             nn.Linear(edge_dim, mid_dim),
             nn.LayerNorm(mid_dim) if use_layer_norm else None,
             nn.ReLU(),
             nn.Linear(mid_dim, mid_dim),
             nn.LayerNorm(mid_dim) if use_layer_norm else None,
             nn.ReLU(),
             nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
         ]
 
         self.net = nn.Sequential(*[m for m in modules if m is not None])
 
     def forward(self, features: Tensor) -> Tensor:
         return self.net(features)
 
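A quick shape check for `RadialProfile` (assumes the repository is importable; the reshape mirrors what `VersatileConvSE3.forward` does below):

```python
import torch
from se3_transformer.model.layers.convolution import RadialProfile

num_edges, edge_dim = 8, 1
profile = RadialProfile(num_freq=5, channels_in=16, channels_out=32, edge_dim=edge_dim)

weights = profile(torch.randn(num_edges, edge_dim))
print(weights.shape)  # torch.Size([8, 2560]) == (num_edges, num_freq * channels_in * channels_out)

# VersatileConvSE3.forward reshapes these into per-edge kernel blocks:
weights = weights.view(-1, 32, 16 * 5)
print(weights.shape)  # torch.Size([8, 32, 80])
```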
 class VersatileConvSE3(nn.Module):
     """
     Building block for TFN convolutions.
     This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions.
     """
 
     def __init__(self,
                  freq_sum: int,
                  channels_in: int,
                  channels_out: int,
                  edge_dim: int,
                  use_layer_norm: bool,
                  fuse_level: ConvSE3FuseLevel):
         super().__init__()
         self.freq_sum = freq_sum
         self.channels_out = channels_out
         self.channels_in = channels_in
         self.fuse_level = fuse_level
         self.radial_func = RadialProfile(num_freq=freq_sum,
                                          channels_in=channels_in,
                                          channels_out=channels_out,
                                          edge_dim=edge_dim,
                                          use_layer_norm=use_layer_norm)
 
     def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
         with nvtx_range(f'VersatileConvSE3'):
             num_edges = features.shape[0]
             in_dim = features.shape[2]
             with nvtx_range(f'RadialProfile'):
                 radial_weights = self.radial_func(invariant_edge_feats) \
                     .view(-1, self.channels_out, self.channels_in * self.freq_sum)
 
             if basis is not None:
                 # This block performs the einsum n i l, n o i f, n l f k -> n o k
-                out_dim = basis.shape[-1]
-                if self.fuse_level != ConvSE3FuseLevel.FULL:
-                    out_dim += out_dim % 2 - 1  # Account for padded basis
                 basis_view = basis.view(num_edges, in_dim, -1)
                 tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
-                return (radial_weights @ tmp)[:, :, :out_dim]
+                return radial_weights @ tmp
             else:
                 # k = l = 0 non-fused case
                 return radial_weights @ features
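The lines removed above are the padded-basis bookkeeping; the fix moves it into the new `ConvSE3._try_unpad` helper below, so it is only applied where the basis is actually padded. The trim arithmetic, as a sketch: bases with an odd natural width `2 * d_out + 1` are presumably padded by one column to an even width, and `out_dim % 2 - 1` removes that column again.

```python
# out_dim += out_dim % 2 - 1 trims even (padded) widths back to odd ones;
# odd widths are left unchanged (x + 1 - 1 == x).
for padded_dim in [2, 4, 6, 8, 10]:
    out_dim = padded_dim + padded_dim % 2 - 1
    print(padded_dim, '->', out_dim)  # 2 -> 1, 4 -> 3, 6 -> 5, 8 -> 7, 10 -> 9
```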
 
 
 class ConvSE3(nn.Module):
     """
     SE(3)-equivariant graph convolution (Tensor Field Network convolution).
     This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance.
     Features of different degrees interact together to produce output features.
 
     Note 1:
         The option is given to not pool the output. This means that the convolution sum over neighbors will not be
         done, and the returned features will be edge features instead of node features.
 
     Note 2:
         Unlike the original paper and implementation, this convolution can handle edge features of degree greater than 0.
         Input edge features are concatenated with input source node features before the kernel is applied.
     """
 
     def __init__(
             self,
             fiber_in: Fiber,
             fiber_out: Fiber,
             fiber_edge: Fiber,
             pool: bool = True,
             use_layer_norm: bool = False,
             self_interaction: bool = False,
             max_degree: int = 4,
             fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
             allow_fused_output: bool = False
     ):
         """
         :param fiber_in: Fiber describing the input features
         :param fiber_out: Fiber describing the output features
         :param fiber_edge: Fiber describing the edge features (node distances excluded)
         :param pool: If True, compute final node features by averaging incoming edge features
         :param use_layer_norm: Apply layer normalization between MLP layers
         :param self_interaction: Apply self-interaction of nodes
         :param max_degree: Maximum degree used in the bases computation
         :param fuse_level: Maximum fuse level to use in TFN convolutions
         :param allow_fused_output: Allow the module to output a fused representation of features
         """
         super().__init__()
         self.pool = pool
         self.fiber_in = fiber_in
         self.fiber_out = fiber_out
         self.self_interaction = self_interaction
         self.max_degree = max_degree
         self.allow_fused_output = allow_fused_output
 
         # channels_in: account for the concatenation of edge features
         channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
         channels_out_set = set([f.channels for f in self.fiber_out])
         unique_channels_in = (len(channels_in_set) == 1)
         unique_channels_out = (len(channels_out_set) == 1)
         degrees_up_to_max = list(range(max_degree + 1))
         common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm)
 
         if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
                 unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
                 unique_channels_out and fiber_out.degrees == degrees_up_to_max:
             # Single fused convolution
             self.used_fuse_level = ConvSE3FuseLevel.FULL
 
             sum_freq = sum([
                 degree_to_dim(min(d_in, d_out))
                 for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
             ])
 
             self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0],
                                          fuse_level=self.used_fuse_level, **common_args)
 
         elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
                 unique_channels_in and fiber_in.degrees == degrees_up_to_max:
             # Convolutions fused per output degree
             self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
             self.conv_out = nn.ModuleDict()
             for d_out, c_out in fiber_out:
                 sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
                 self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out,
                                                              fuse_level=self.used_fuse_level, **common_args)
 
         elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
                 unique_channels_out and fiber_out.degrees == degrees_up_to_max:
             # Convolutions fused per input degree
             self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
             self.conv_in = nn.ModuleDict()
             for d_in, c_in in fiber_in:
                 sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
-                self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, c_in, list(channels_out_set)[0],
+                channels_in_new = c_in + fiber_edge[d_in] * (d_in > 0)
+                self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, channels_in_new, list(channels_out_set)[0],
                                                            fuse_level=self.used_fuse_level, **common_args)
         else:
             # Use pairwise TFN convolutions
             self.used_fuse_level = ConvSE3FuseLevel.NONE
             self.conv = nn.ModuleDict()
             for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out):
                 dict_key = f'{degree_in},{degree_out}'
                 channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
                 sum_freq = degree_to_dim(min(degree_in, degree_out))
                 self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out,
                                                        fuse_level=self.used_fuse_level, **common_args)
 
         if self_interaction:
             self.to_kernel_self = nn.ParameterDict()
             for degree_out, channels_out in fiber_out:
                 if fiber_in[degree_out]:
                     self.to_kernel_self[str(degree_out)] = nn.Parameter(
                         torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
 
+    def _try_unpad(self, feature, basis):
+        # Account for padded basis
+        if basis is not None:
+            out_dim = basis.shape[-1]
+            out_dim += out_dim % 2 - 1
+            return feature[..., :out_dim]
+        else:
+            return feature
+
     def forward(
             self,
             node_feats: Dict[str, Tensor],
             edge_feats: Dict[str, Tensor],
             graph: DGLGraph,
             basis: Dict[str, Tensor]
     ):
         with nvtx_range(f'ConvSE3'):
             invariant_edge_feats = edge_feats['0'].squeeze(-1)
             src, dst = graph.edges()
             out = {}
             in_features = []
 
             # Fetch all input features from edge and node features
             for degree_in in self.fiber_in.degrees:
                 src_node_features = node_feats[str(degree_in)][src]
                 if degree_in > 0 and str(degree_in) in edge_feats:
                     # Handle edge features of any type by concatenating them to node features
                     src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
                 in_features.append(src_node_features)
 
             if self.used_fuse_level == ConvSE3FuseLevel.FULL:
                 in_features_fused = torch.cat(in_features, dim=-1)
                 out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])
 
                 if not self.allow_fused_output or self.self_interaction or self.pool:
                     out = unfuse_features(out, self.fiber_out.degrees)
 
             elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
                 in_features_fused = torch.cat(in_features, dim=-1)
                 for degree_out in self.fiber_out.degrees:
-                    out[str(degree_out)] = self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats, basis[f'out{degree_out}_fused'])
+                    basis_used = basis[f'out{degree_out}_fused']
 
             elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
                 out = 0
                 for degree_in, feature in zip(self.fiber_in.degrees, in_features):
                     out = out + self.conv_in[str(degree_in)](feature, invariant_edge_feats, basis[f'in{degree_in}_fused'])
                 if not self.allow_fused_output or self.self_interaction or self.pool:
                     out = unfuse_features(out, self.fiber_out.degrees)
             else:
|
out[str(degree_out)] = self._try_unpad(
|
||||||
# Fallback to pairwise TFN convolutions
|
self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats, basis_used),
|
||||||
for degree_out in self.fiber_out.degrees:
|
basis_used)
|
||||||
out_feature = 0
|
|
||||||
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
|
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
|
||||||
dict_key = f'{degree_in},{degree_out}'
|
out = 0
|
||||||
out_feature = out_feature + self.conv[dict_key](feature, invariant_edge_feats, basis.get(dict_key, None))
|
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
|
||||||
out[str(degree_out)] = out_feature
|
out = out + self.conv_in[str(degree_in)](feature, invariant_edge_feats, basis[f'in{degree_in}_fused'])
|
||||||
|
if not self.allow_fused_output or self.self_interaction or self.pool:
|
||||||
for degree_out in self.fiber_out.degrees:
|
out = unfuse_features(out, self.fiber_out.degrees)
|
||||||
if self.self_interaction and str(degree_out) in self.to_kernel_self:
|
else:
|
||||||
with nvtx_range(f'self interaction'):
|
# Fallback to pairwise TFN convolutions
|
||||||
dst_features = node_feats[str(degree_out)][dst]
|
for degree_out in self.fiber_out.degrees:
|
||||||
kernel_self = self.to_kernel_self[str(degree_out)]
|
out_feature = 0
|
||||||
out[str(degree_out)] = out[str(degree_out)] + kernel_self @ dst_features
|
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
|
||||||
|
dict_key = f'{degree_in},{degree_out}'
|
||||||
if self.pool:
|
basis_used = basis.get(dict_key, None)
|
||||||
with nvtx_range(f'pooling'):
|
out_feature = out_feature + self._try_unpad(
|
||||||
if isinstance(out, dict):
|
self.conv[dict_key](feature, invariant_edge_feats, basis_used),
|
||||||
out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
|
basis_used)
|
||||||
else:
|
out[str(degree_out)] = out_feature
|
||||||
out = dgl.ops.copy_e_sum(graph, out)
|
|
||||||
return out
|
for degree_out in self.fiber_out.degrees:
|
||||||
|
if self.self_interaction and str(degree_out) in self.to_kernel_self:
|
||||||
|
with nvtx_range(f'self interaction'):
|
||||||
|
dst_features = node_feats[str(degree_out)][dst]
|
||||||
|
kernel_self = self.to_kernel_self[str(degree_out)]
|
||||||
|
out[str(degree_out)] = out[str(degree_out)] + kernel_self @ dst_features
|
||||||
|
|
||||||
|
if self.pool:
|
||||||
|
with nvtx_range(f'pooling'):
|
||||||
|
if isinstance(out, dict):
|
||||||
|
out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
|
||||||
|
else:
|
||||||
|
out = dgl.ops.copy_e_sum(graph, out)
|
||||||
|
return out
|
||||||
|
|
|
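The `_try_unpad` helper added above relies on true basis widths being odd (`degree_to_dim` gives 2 * degree + 1), so an even trailing dimension can only come from the Tensor Cores pad trick and is trimmed by one column. A minimal standalone sketch of the rule, with toy shapes (not part of the commit):

import torch

def try_unpad(feature, basis):
    # Mirrors ConvSE3._try_unpad: real basis widths are odd (2 * degree + 1),
    # so an even last dimension means one column of Tensor Cores padding to drop.
    if basis is None:
        return feature
    out_dim = basis.shape[-1]
    out_dim += out_dim % 2 - 1  # even -> trim one column, odd -> unchanged
    return feature[..., :out_dim]

padded_basis = torch.zeros(10, 4)              # degree-1 basis (width 3) padded to 4
feature = torch.randn(10, 16, 4)
print(try_unpad(feature, padded_basis).shape)  # torch.Size([10, 16, 3])
print(try_unpad(feature, None).shape)          # torch.Size([10, 16, 4])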
se3_transformer/model/layers/linear.py
@@ -1,59 +1,59 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT


from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from se3_transformer.model.fiber import Fiber


class LinearSE3(nn.Module):
    """
    Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
    Maps a fiber to a fiber with the same degrees (channels may be different).
    No interaction between degrees, but interaction between channels.

    type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels)
    type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels)
                                                 :
    type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels)
    """

    def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
        super().__init__()
        self.weights = nn.ParameterDict({
            str(degree_out): nn.Parameter(
                torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
            for degree_out, channels_out in fiber_out
        })

    def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
        return {
            degree: self.weights[degree] @ features[degree]
            for degree, weight in self.weights.items()
        }
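A minimal usage sketch for `LinearSE3`, assuming per-degree features shaped `(num_nodes, channels, 2 * degree + 1)` as elsewhere in this model (import path as in this repository):

import torch
from se3_transformer.model.fiber import Fiber
from se3_transformer.model.layers.linear import LinearSE3

layer = LinearSE3(fiber_in=Fiber({0: 16, 1: 16}), fiber_out=Fiber({0: 32, 1: 8}))
feats = {'0': torch.randn(10, 16, 1),   # 10 nodes, 16 type-0 channels
         '1': torch.randn(10, 16, 3)}   # 10 nodes, 16 type-1 channels
out = layer(feats)
print(out['0'].shape, out['1'].shape)   # torch.Size([10, 32, 1]) torch.Size([10, 8, 3])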
se3_transformer/model/layers/norm.py
@@ -1,83 +1,83 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT


from typing import Dict

import torch
import torch.nn as nn
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range

from se3_transformer.model.fiber import Fiber


class NormSE3(nn.Module):
    """
    Norm-based SE(3)-equivariant nonlinearity.

                 ┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐
    feature_in ──┤                                              * ──> feature_out
                 └──> feature_phase ────────────────────────────┘
    """

    NORM_CLAMP = 2 ** -24  # Minimum positive subnormal for FP16

    def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
        super().__init__()
        self.fiber = fiber
        self.nonlinearity = nonlinearity

        if len(set(fiber.channels)) == 1:
            # Fuse all the layer normalizations into a group normalization
            self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
        else:
            # Use multiple layer normalizations
            self.layer_norms = nn.ModuleDict({
                str(degree): nn.LayerNorm(channels)
                for degree, channels in fiber
            })

    def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
        with nvtx_range('NormSE3'):
            output = {}
            if hasattr(self, 'group_norm'):
                # Compute per-degree norms of features
                norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
                         for d in self.fiber.degrees]
                fused_norms = torch.cat(norms, dim=-2)

                # Transform the norms only
                new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
                new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)

                # Scale features to the new norms
                for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
                    output[str(d)] = features[str(d)] / norm * new_norm
            else:
                for degree, feat in features.items():
                    norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
                    new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
                    output[degree] = new_norm * feat / norm

            return output
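A short sketch of what `NormSE3` computes (equal channel counts per degree, so the fused GroupNorm path above is taken): only the norms pass through the nonlinearity, and each output vector stays parallel to its input, which is what preserves equivariance.

import torch
from se3_transformer.model.fiber import Fiber
from se3_transformer.model.layers.norm import NormSE3

norm_layer = NormSE3(Fiber({0: 8, 1: 8}))  # equal channels -> fused GroupNorm path
feats = {'0': torch.randn(10, 8, 1), '1': torch.randn(10, 8, 3)}
out = norm_layer(feats)
print(out['1'].shape)  # torch.Size([10, 8, 3]); each vector parallel to its input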
se3_transformer/model/layers/pooling.py
@@ -1,53 +1,53 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT


from typing import Dict, Literal

import torch.nn as nn
from dgl import DGLGraph
from dgl.nn.pytorch import AvgPooling, MaxPooling
from torch import Tensor


class GPooling(nn.Module):
    """
    Graph max/average pooling on a given feature type.
    The average can be taken for any feature type, and equivariance will be maintained.
    The maximum can only be taken for invariant features (type 0).
    If you want max-pooling for type > 0 features, look into Vector Neurons.
    """

    def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
        """
        :param feat_type: Feature type to pool
        :param pool: Type of pooling: max or avg
        """
        super().__init__()
        assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
        assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
        self.feat_type = feat_type
        self.pool = MaxPooling() if pool == 'max' else AvgPooling()

    def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
        pooled = self.pool(graph, features[str(self.feat_type)])
        return pooled.squeeze(dim=-1)
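A usage sketch for `GPooling` on a toy random graph (`dgl.rand_graph` assumed available in the installed DGL version):

import dgl
import torch
from se3_transformer.model.layers.pooling import GPooling

graph = dgl.rand_graph(10, 30)          # 10 nodes, 30 edges, single graph batch
pool = GPooling(feat_type=0, pool='avg')
feats = {'0': torch.randn(10, 8, 1)}    # type-0 node features
print(pool(feats, graph=graph).shape)   # torch.Size([1, 8]): one vector per graph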
se3_transformer/model/transformer.py
@@ -1,222 +1,222 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT


import logging
from typing import Optional, Literal, Dict

import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor

from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.layers.attention import AttentionBlockSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling
from se3_transformer.runtime.utils import str2bool
from se3_transformer.model.fiber import Fiber


class Sequential(nn.Sequential):
    """ Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """

    def forward(self, input, *args, **kwargs):
        for module in self:
            input = module(input, *args, **kwargs)
        return input


def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
    """ Add relative positions to existing edge features """
    edge_features = edge_features.copy() if edge_features else {}
    r = relative_pos.norm(dim=-1, keepdim=True)
    if '0' in edge_features:
        edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
    else:
        edge_features['0'] = r[..., None]

    return edge_features


class SE3Transformer(nn.Module):
    def __init__(self,
                 num_layers: int,
                 fiber_in: Fiber,
                 fiber_hidden: Fiber,
                 fiber_out: Fiber,
                 num_heads: int,
                 channels_div: int,
                 fiber_edge: Fiber = Fiber({}),
                 return_type: Optional[int] = None,
                 pooling: Optional[Literal['avg', 'max']] = None,
                 norm: bool = True,
                 use_layer_norm: bool = True,
                 tensor_cores: bool = False,
                 low_memory: bool = False,
                 **kwargs):
        """
        :param num_layers: Number of attention layers
        :param fiber_in: Input fiber description
        :param fiber_hidden: Hidden fiber description
        :param fiber_out: Output fiber description
        :param fiber_edge: Input edge fiber description
        :param num_heads: Number of attention heads
        :param channels_div: Channels division before feeding to attention layer
        :param return_type: Return only features of this type
        :param pooling: 'avg' or 'max' graph pooling before MLP layers
        :param norm: Apply a normalization layer after each attention block
        :param use_layer_norm: Apply layer normalization between MLP layers
        :param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
        :param low_memory: If True, will use slower ops that use less memory
        """
        super().__init__()
        self.num_layers = num_layers
        self.fiber_edge = fiber_edge
        self.num_heads = num_heads
        self.channels_div = channels_div
        self.return_type = return_type
        self.pooling = pooling
        self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
        self.tensor_cores = tensor_cores
        self.low_memory = low_memory

        if low_memory and not tensor_cores:
            logging.warning('Low memory mode will have no effect with no Tensor Cores')

        # Fully fused convolutions when using Tensor Cores (and not low memory mode)
        fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL

        graph_modules = []
        for i in range(num_layers):
            graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
                                                   fiber_out=fiber_hidden,
                                                   fiber_edge=fiber_edge,
                                                   num_heads=num_heads,
                                                   channels_div=channels_div,
                                                   use_layer_norm=use_layer_norm,
                                                   max_degree=self.max_degree,
                                                   fuse_level=fuse_level))
            if norm:
                graph_modules.append(NormSE3(fiber_hidden))
            fiber_in = fiber_hidden

        graph_modules.append(ConvSE3(fiber_in=fiber_in,
                                     fiber_out=fiber_out,
                                     fiber_edge=fiber_edge,
                                     self_interaction=True,
                                     use_layer_norm=use_layer_norm,
                                     max_degree=self.max_degree))
        self.graph_modules = Sequential(*graph_modules)

        if pooling is not None:
            assert return_type is not None, 'return_type must be specified when pooling'
            self.pooling_module = GPooling(pool=pooling, feat_type=return_type)

    def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
                edge_feats: Optional[Dict[str, Tensor]] = None,
                basis: Optional[Dict[str, Tensor]] = None):
        # Compute bases in case they weren't precomputed as part of the data loading
        basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
                                   use_pad_trick=self.tensor_cores and not self.low_memory,
                                   amp=torch.is_autocast_enabled())

        # Add fused bases (per output degree, per input degree, and fully fused) to the dict
        basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
                                        fully_fused=self.tensor_cores and not self.low_memory)

        edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)

        node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)

        if self.pooling is not None:
            return self.pooling_module(node_feats, graph=graph)

        if self.return_type is not None:
            return node_feats[str(self.return_type)]

        return node_feats

    @staticmethod
    def add_argparse_args(parser):
        parser.add_argument('--num_layers', type=int, default=7,
                            help='Number of stacked Transformer layers')
        parser.add_argument('--num_heads', type=int, default=8,
                            help='Number of heads in self-attention')
        parser.add_argument('--channels_div', type=int, default=2,
                            help='Channels division before feeding to attention layer')
        parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
                            help='Type of graph pooling')
        parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
                            help='Apply a normalization layer after each attention block')
        parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
                            help='Apply layer normalization between MLP layers')
        parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
                            help='If true, will use fused ops that are slower but that use less memory '
                                 '(expect 25 percent less memory). '
                                 'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')

        return parser


class SE3TransformerPooled(nn.Module):
    def __init__(self,
                 fiber_in: Fiber,
                 fiber_out: Fiber,
                 fiber_edge: Fiber,
                 num_degrees: int,
                 num_channels: int,
                 output_dim: int,
                 **kwargs):
        super().__init__()
        kwargs['pooling'] = kwargs['pooling'] or 'max'
        self.transformer = SE3Transformer(
            fiber_in=fiber_in,
            fiber_hidden=Fiber.create(num_degrees, num_channels),
            fiber_out=fiber_out,
            fiber_edge=fiber_edge,
            return_type=0,
            **kwargs
        )

        n_out_features = fiber_out.num_features
        self.mlp = nn.Sequential(
            nn.Linear(n_out_features, n_out_features),
            nn.ReLU(),
            nn.Linear(n_out_features, output_dim)
        )

    def forward(self, graph, node_feats, edge_feats, basis=None):
        feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
        y = self.mlp(feats).squeeze(-1)
        return y

    @staticmethod
    def add_argparse_args(parent_parser):
        parser = parent_parser.add_argument_group("Model architecture")
        SE3Transformer.add_argparse_args(parser)
        parser.add_argument('--num_degrees',
                            help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
                            type=int, default=4)
        parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
        return parent_parser
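An end-to-end usage sketch of `SE3Transformer` with illustrative sizes; the forward pass computes bases from `graph.edata['rel_pos']`, so that field must be set on the toy graph:

import dgl
import torch
from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber

model = SE3Transformer(num_layers=2,
                       fiber_in=Fiber({0: 4}),
                       fiber_hidden=Fiber.create(3, 16),  # degrees 0..2, 16 channels each
                       fiber_out=Fiber({0: 8}),
                       num_heads=2,
                       channels_div=2)
graph = dgl.rand_graph(20, 60)
graph.edata['rel_pos'] = torch.randn(60, 3)   # relative positions drive the bases
node_feats = {'0': torch.randn(20, 4, 1)}
out = model(graph, node_feats)                # dict of per-degree output features
print(out['0'].shape)                         # torch.Size([20, 8, 1])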
se3_transformer/runtime/arguments.py
@@ -1,70 +1,70 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT


import argparse
import pathlib

from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.runtime.utils import str2bool

PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')

paths = PARSER.add_argument_group('Paths')
paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'),
                   help='Directory where the data is located or should be downloaded')
paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'),
                   help='Directory where the results logs should be saved')
paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json',
                   help='Name for the resulting DLLogger JSON file')
paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None,
                   help='File where the checkpoint should be saved')
paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None,
                   help='File of the checkpoint to be loaded')

optimizer = PARSER.add_argument_group('Optimizer')
optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam')
optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002)
optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None)
optimizer.add_argument('--momentum', type=float, default=0.9)
optimizer.add_argument('--weight_decay', type=float, default=0.1)

PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size')
PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally')
PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers')

PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision')
PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
-PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=1,
+PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=20,
                    help='Do an evaluation round every N epochs')
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
                    help='Minimize stdout output')

PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
                    help='Benchmark mode')

QM9DataModule.add_argparse_args(PARSER)
SE3TransformerPooled.add_argparse_args(PARSER)
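A quick sanity check of the parser defaults, including the `--eval_interval` default changed by this commit (assuming the package and its dependencies are importable):

from se3_transformer.runtime.arguments import PARSER

args = PARSER.parse_args([])        # parse with defaults only
assert args.eval_interval == 20     # changed from 1 in this commit
assert args.ckpt_interval == -1
print(args.batch_size, args.learning_rate)  # 240 0.002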
@ -1,240 +1,240 @@
|
||||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
#
|
#
|
||||||
# Permission is hereby granted, free of charge, to any person obtaining a
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
# copy of this software and associated documentation files (the "Software"),
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
# to deal in the Software without restriction, including without limitation
|
# to deal in the Software without restriction, including without limitation
|
||||||
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
# and/or sell copies of the Software, and to permit persons to whom the
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
# Software is furnished to do so, subject to the following conditions:
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
#
|
#
|
||||||
# The above copyright notice and this permission notice shall be included in
|
# The above copyright notice and this permission notice shall be included in
|
||||||
# all copies or substantial portions of the Software.
|
# all copies or substantial portions of the Software.
|
||||||
#
|
#
|
||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
# DEALINGS IN THE SOFTWARE.
|
# DEALINGS IN THE SOFTWARE.
|
||||||
#
|
#
|
||||||
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||

import logging
import pathlib
from typing import List

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from apex.optimizers import FusedAdam, FusedLAMB
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm

from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
    PerformanceCallback
from se3_transformer.runtime.inference import evaluate
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
    using_tensor_cores, increase_l2_fetch_granularity

def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
    """ Saves model, optimizer and epoch states to path (only once per node) """
    if get_local_rank() == 0:
        state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
        checkpoint = {
            'state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch
        }
        for callback in callbacks:
            callback.on_checkpoint_save(checkpoint)

        torch.save(checkpoint, str(path))
        logging.info(f'Saved checkpoint to {str(path)}')

def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
    """ Loads model, optimizer and epoch states from path """
    checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
    if isinstance(model, DistributedDataParallel):
        model.module.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    for callback in callbacks:
        callback.on_checkpoint_load(checkpoint)

    logging.info(f'Loaded checkpoint from {str(path)}')
    return checkpoint['epoch']

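# A minimal resume sketch (path hypothetical; `model` and `optimizer` must be built
# with the same hyperparameters used when the checkpoint was written):
#   epoch_start = load_state(model, optimizer, pathlib.Path('results/model_qm9.pth'), callbacks=[])
#   # ... then continue training from `epoch_start`, as train() below does
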
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
    losses = []
    for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
                         desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
        *inputs, target = to_cuda(batch)

        for callback in callbacks:
            callback.on_batch_start()

        with torch.cuda.amp.autocast(enabled=args.amp):
            pred = model(*inputs)
            loss = loss_fn(pred, target) / args.accumulate_grad_batches

        grad_scaler.scale(loss).backward()

        # gradient accumulation
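        # (the optimizer steps once every `accumulate_grad_batches` iterations, and on
        # the final batch, so the effective batch size is batch_size * accumulate_grad_batches)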
        if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
            if args.gradient_clip:
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)

            grad_scaler.step(optimizer)
            grad_scaler.update()
            model.zero_grad(set_to_none=True)

        losses.append(loss.item())

    return np.mean(losses)

def train(model: nn.Module,
          loss_fn: _Loss,
          train_dataloader: DataLoader,
          val_dataloader: DataLoader,
          callbacks: List[BaseCallback],
          logger: Logger,
          args):
    device = torch.cuda.current_device()
    model.to(device=device)
    local_rank = get_local_rank()
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    if dist.is_initialized():
        model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

    model.train()
    grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                                    weight_decay=args.weight_decay)
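    # FusedAdam / FusedLAMB are apex's fused CUDA optimizers; any other value of
    # args.optimizer falls back to plain torch.optim.SGD with momentum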

    epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0

    for callback in callbacks:
        callback.on_fit_start(optimizer, args)

    for epoch_idx in range(epoch_start, args.epochs):
        if isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch_idx)

        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
        if dist.is_initialized():
            loss = torch.tensor(loss, dtype=torch.float, device=device)
            torch.distributed.all_reduce(loss)
            loss = (loss / world_size).item()
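            # all_reduce sums the per-rank mean losses (the default op is SUM), so
            # dividing by world_size gives the global average train loss for logging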

        logging.info(f'Train loss: {loss}')
        logger.log_metrics({'train loss': loss}, epoch_idx)

        for callback in callbacks:
            callback.on_epoch_end()

        if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
                and (epoch_idx + 1) % args.ckpt_interval == 0:
            save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)

        if not args.benchmark and ((args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0) or epoch_idx + 1 == args.epochs):
            evaluate(model, val_dataloader, callbacks, args)
            model.train()

            for callback in callbacks:
                callback.on_validation_end(epoch_idx)

    if args.save_ckpt_path is not None and not args.benchmark:
        save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)

    for callback in callbacks:
        callback.on_fit_end()

def print_parameters_count(model):
    num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f'Number of trainable parameters: {num_params_trainable}')

if __name__ == '__main__':
    is_distributed = init_distributed()
    local_rank = get_local_rank()
    args = PARSER.parse_args()

    logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)

    logging.info('====== SE(3)-Transformer ======')
    logging.info('| Training procedure |')
    logging.info('===============================')

    if args.seed is not None:
        logging.info(f'Using seed {args.seed}')
        seed_everything(args.seed)

    logger = LoggerCollection([
        DLLogger(save_dir=args.log_dir, filename=args.dllogger_name),
        WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
    ])

    datamodule = QM9DataModule(**vars(args))
    model = SE3TransformerPooled(
        fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
        fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
        fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
        output_dim=1,
        tensor_cores=using_tensor_cores(args.amp),  # use Tensor Cores more effectively
        **vars(args)
    )
    loss_fn = nn.L1Loss()

    if args.benchmark:
        logging.info('Running benchmark mode')
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
    else:
        callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
                     QM9LRSchedulerCallback(logger, epochs=args.epochs)]

    if is_distributed:
        gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())

    print_parameters_count(model)
    logger.log_hyperparams(vars(args))
    increase_l2_fetch_granularity()
    train(model,
          loss_fn,
          datamodule.train_dataloader(),
          datamodule.val_dataloader(),
          callbacks,
          logger,
          args)

    logging.info('Training finished successfully')
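
# Training is normally launched via the provided shell scripts (e.g. scripts/train.sh);
# a minimal single-GPU sketch (flag values hypothetical):
#   python training.py --amp --batch_size 240 --epochs 100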
@ -1,102 +1,102 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import torch

from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber
from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot

# Tolerance for the equivariance error abs( f(x) @ R - f(x @ R) )
TOL = 1e-3
CHANNELS, NODES = 32, 512

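# Each test below applies a random rotation R in two ways: rotating the inputs before
# the forward pass, and rotating the outputs after it; type-0 outputs must match
# unrotated (invariance) and type-1 outputs must match after rotation (equivariance).
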
def _get_outputs(model, R):
    feats0 = torch.randn(NODES, CHANNELS, 1)
    feats1 = torch.randn(NODES, CHANNELS, 3)

    coords = torch.randn(NODES, 3)
    graph = get_random_graph(NODES)
    if torch.cuda.is_available():
        feats0 = feats0.cuda()
        feats1 = feats1.cuda()
        R = R.cuda()
        coords = coords.cuda()
        graph = graph.to('cuda')
        model.cuda()

    graph1 = assign_relative_pos(graph, coords)
    out1 = model(graph1, {'0': feats0, '1': feats1}, {})
    graph2 = assign_relative_pos(graph, coords @ R)
    out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {})

    return out1, out2

def _get_model(**kwargs):
    return SE3Transformer(
        num_layers=4,
        fiber_in=Fiber.create(2, CHANNELS),
        fiber_hidden=Fiber.create(3, CHANNELS),
        fiber_out=Fiber.create(2, CHANNELS),
        fiber_edge=Fiber({}),
        num_heads=8,
        channels_div=2,
        **kwargs
    )

def test_equivariance():
    model = _get_model()
    R = rot(*torch.rand(3))
    if torch.cuda.is_available():
        R = R.cuda()
    out1, out2 = _get_outputs(model, R)

    assert torch.allclose(out2['0'], out1['0'], atol=TOL), \
        f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}'
    assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \
        f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}'

def test_equivariance_pooled():
    model = _get_model(pooling='avg', return_type=1)
    R = rot(*torch.rand(3))
    if torch.cuda.is_available():
        R = R.cuda()
    out1, out2 = _get_outputs(model, R)

    assert torch.allclose(out2, (out1 @ R), atol=TOL), \
        f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}'

def test_invariance_pooled():
    model = _get_model(pooling='avg', return_type=0)
    R = rot(*torch.rand(3))
    if torch.cuda.is_available():
        R = R.cuda()
    out1, out2 = _get_outputs(model, R)

    assert torch.allclose(out2, out1, atol=TOL), \
        f'type-0 features should be invariant {get_max_diff(out1, out2)}'
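
# A minimal sketch of running these checks (assuming pytest is available):
#   pytest tests/test_equivariance.py -v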
@ -1,60 +1,60 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import dgl
import torch

def get_random_graph(N, num_edges_factor=18):
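    # N nodes with N * num_edges_factor uniformly sampled edges; self-loops are
    # removed afterwards, so the realized edge count can be slightly lower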
    graph = dgl.transform.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor))
    return graph

def assign_relative_pos(graph, coords):
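    # store each edge's relative position (coords[src] - coords[dst]) as the
    # 'rel_pos' edge feature consumed by the SE(3)-Transformer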
    src, dst = graph.edges()
    graph.edata['rel_pos'] = coords[src] - coords[dst]
    return graph

def get_max_diff(a, b):
    return (a - b).abs().max().item()

def rot_z(gamma):
    return torch.tensor([
        [torch.cos(gamma), -torch.sin(gamma), 0],
        [torch.sin(gamma), torch.cos(gamma), 0],
        [0, 0, 1]
    ], dtype=gamma.dtype)

def rot_y(beta):
    return torch.tensor([
        [torch.cos(beta), 0, torch.sin(beta)],
        [0, 1, 0],
        [-torch.sin(beta), 0, torch.cos(beta)]
    ], dtype=beta.dtype)

def rot(alpha, beta, gamma):
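    # Z-Y-Z Euler-angle composition, a standard parameterization of 3D rotations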
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)